
Commit 8e17473

vthumbe1503, pre-commit-ci[bot], yaox12, and timmoon10 committed
parent de9ef2f
Author: Varun Thumbe

Resolve merge conflict and apply pre-commit.ci auto fixes (see https://pre-commit.ci).
Signed-off-by: Varun Thumbe

Squashed changes merged from main:

* FP8 AllGather in FP8 GroupedGEMM + fix stream usage issue (NVIDIA#2086)
  - Support current-scaling FP8 quantization with a given amax.
  - Support FP8 all-gather in the forward pass and BF16 reduce-scatter in the backward pass;
    the workflow is AR-max -> FP8 quantize -> FP8 all-gather -> FP8 grouped GEMM.
  - Document the new arguments, add unit tests (placed in the L1 suite), and add license headers.
  - Store/reset the quantizer only in the FP8 path; support all layouts on Blackwell and newer.
  - Fix the wrong stream used by device-to-device copies in the grouped-GEMM FFI.
  (Ming Huang, with Phuong Nguyen)
* [JAX] Delay MeshResource validation until first usage (NVIDIA#2124) (Jeremy Berchtold, Phuong Nguyen)
* [JAX] Add a `dot_1_output` sharding constraint and use AXIS_IS_UNSHARDED (NVIDIA#2128) (Phuong Nguyen)
* [JAX] Add an amax input to DBiasQuantizePrimitive and its FFI; initialize amax to zero and fix the sharding rule (NVIDIA#2118) (Phuong Nguyen)
* Further relax constraints to cuDNN 9.13 for disabling fused attention with KV caching (NVIDIA#2121) (Kshitij Lakhani)
* Temporarily remove comm_gemm tests (NVIDIA#2133) (Vladimir Cherepanov)
* [PyTorch] Disable determinism for sm100+ with cuDNN < 9.14 and remove sm100 from the determinism table (NVIDIA#2130) (Charlene Yang)
* [PyTorch] ONNX export of FP8 current scaling; compute amax in the normalization forward pass in untuned kernels (NVIDIA#2068) (Jan Bielak, Pawel Gadzinski)
* [PyTorch][MoE] Use torch.empty instead of from_blob for experts receiving zero tokens (NVIDIA#2134) (zhongboz, Kirthi Shankar Sivamani)
* build: pull cached wheels (NVIDIA#2127) (oliver könig, Kirthi Shankar Sivamani)
* [Common] Add checks to CUDA kernel launches and CUDA API calls, remove exceptions from destructors, and fix dispatch in LayerNorm/RMSNorm (NVIDIA#2074) (Xin Yao, Tim Moon)
* [PyTorch] Support BF16 + FP8 models under CUDA graphs (NVIDIA#2098) (Robin Zhang, Tim Moon)
* Dropout with 8-bit RNG (NVIDIA#2014): do not require the dropout probability to be representable in 8 bits, fix a small statistical bug (less-equal instead of less-than), interpret masks as bytes rather than 16-bit uints, and refactor the kernels (Tim Moon, Vasudevan Rengasamy)
* Create GPU reload buffers on the main stream (NVIDIA#2131) (Selvaraj Anandaraj, Paweł Gadziński)
* Fix CI failures for UB overlap changes (NVIDIA#2149) (djns99)
* [JAX] Fix failing fused-attention tests for dropout=0.1 and bias on sm100, add a helper to query all device compute capabilities, and represent the attention bias with an enum instead of a string (NVIDIA#2135) (Kshitij Lakhani)
* [PyTorch][CUDA Graph] Fix the FP8 weight quantization cache under CUDA graphs: add a no-op to amax computation and fix the FP8 blockwise recipe (NVIDIA#2119) (zhongboz, Tim Moon)
* [PyTorch] Fix cross-entropy vanishing gradients and generalize the backward kernel to handle reduced and unreduced loss (NVIDIA#2139) (Casper, Tim Moon)
* Fix a bug when enabling --overlap-grad-reduce in Megatron-Core (NVIDIA#2142) (Hongbin Liu)
* Fix the CUDA version in setup.py, re-enable building comm-gemm tests, and work around the nvidia-nvshmem package (NVIDIA#2132) (Vladimir Cherepanov, Tim Moon)
* [JAX] NoScaleTensor wrapper for non-quantized data: rename HighPrecisionTensor to NoScaleTensor, reuse an existing amax for current scaling instead of recomputing it, cast non-quantized kernels to the input dtype in VJPs, and fix a Shardy issue with amax shaped (1, 1, 1) instead of (1,) (NVIDIA#2136) (Jeremy Berchtold)
* [JAX] Fix GroupedScaledTensor creation with a keyword argument (NVIDIA#2154) (Phuong Nguyen)
* Fix a few issues with multi-process launching (NVIDIA#2155) (Ming Huang, Phuong Nguyen)
* Update the list of authorized CI users (NVIDIA#2152) (Tim Moon)
* Fused RoPE with combined QKV input (NVIDIA#2122): add a fused QKV RoPE kernel with support for rotary percent and margin, CP size 2, start_positions, and interleaved layouts; optimize kernel performance; move the fused_qkv_rope test into test_fused_rope.py; and apply the shared-memory optimization to the separate fused RoPE kernels (Vasudevan Rengasamy, Xin Yao, Tim Moon)
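The first entry above (NVIDIA#2086) quantizes with a scale derived from an externally supplied amax, so that every rank produces identically scaled FP8 shards before the all-gather. A minimal sketch of that quantization step in plain PyTorch is shown below; the helper name and the E4M3 constant are illustrative and are not the Transformer Engine API, whose actual implementation lives in the JAX grouped-GEMM extensions:

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def quantize_with_given_amax(x: torch.Tensor, amax: torch.Tensor):
    # `amax` is assumed to have been all-reduced (MAX) across ranks already,
    # so every rank derives the same scale ("AR-max" in the workflow above).
    scale = FP8_E4M3_MAX / torch.clamp(amax.float(), min=1e-12)
    # FP8 quantize: cast the local shard with the shared current scale.
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    # The FP8 shards can now be all-gathered and fed to the FP8 grouped GEMM;
    # keep the scale (or its inverse) around for dequantization.
    return x_fp8, scale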

File tree

9 files changed, +904 −63 lines changed


tests/pytorch/test_fused_rope.py

Lines changed: 141 additions & 6 deletions
@@ -1,25 +1,32 @@
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, List
 import math
 import torch
 import pytest
 from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
+    apply_fused_qkv_rotary_pos_emb,
 )


 # Gradient is a broadcasted scalar
-def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    return output.sum() * 2
+def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    if isinstance(output, List):
+        return sum(t.sum() * 2 for t in output)
+    else:
+        return output.sum() * 2


 # Gradient is a full tensor
-def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    t = torch.ones_like(output)
-    return torch.sum(output * t)
+def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    if isinstance(output, List):
+        return sum(torch.sum(t * torch.ones_like(t)) for t in output)
+    else:
+        t = torch.ones_like(output)
+        return torch.sum(output * t)


 @pytest.mark.parametrize("start_positions", [True, False])
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
     torch.testing.assert_close(grad_fused, grad_unfused)

     assert output_fused.is_contiguous()
+
+
+@pytest.mark.parametrize("start_positions", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
+@pytest.mark.parametrize("hidden_size", [64, 128, 256])
+@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
+@pytest.mark.parametrize("margin", [0, 10])
+@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
+@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
+def test_fused_qkv_rope(
+    dtype: torch.dtype,
+    seq_length: int,
+    hidden_size: int,
+    rotary_percent: float,
+    margin: int,
+    tensor_format: str,
+    loss_func: Callable,
+    cp_size: int,
+    interleaved: bool,
+    start_positions: bool,
+) -> None:
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # are with the maximum length of the rope embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
+    if seq_length - margin < 0:
+        pytest.skip("Skipping test with seq_length - margin < 0")
+
+    device = torch.device("cuda:0")
+    batch_size, head_num = 2, 64
+
+    t = torch.rand(
+        (seq_length - margin, batch_size, head_num, hidden_size * 6),
+        dtype=dtype,
+        device=device,
+    )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
+    if tensor_format == "bshd":
+        t = t.transpose(0, 1).contiguous()
+    t.requires_grad = True
+
+    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_q = rotary_pos_emb_q(seq_length * cp_size)
+    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_k = rotary_pos_emb_k(seq_length * cp_size)
+
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+        # for more accurate comparison
+
+        t_clone = t.clone()
+        (query, key, value) = torch.split(
+            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
+        )
+        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
+
+        query_unfused = apply_rotary_pos_emb(
+            query,
+            emb_q,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
+        key_unfused = apply_rotary_pos_emb(
+            key,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
+        value_unfused = value
+        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_unfused.backward()
+            grad_unfused = t.grad.detach().clone()
+
+        t.grad = None
+
+        # fused
+        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
+            t,
+            emb_q,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
+        )
+        loss_fused = loss_func([query_fused, key_fused, value_fused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_fused.backward()
+            grad_fused = t.grad.detach().clone()
+            t.grad = None
+
+        torch.testing.assert_close(query_fused, query_unfused)
+        torch.testing.assert_close(key_fused, key_unfused)
+        torch.testing.assert_close(value_fused, value_unfused)
+
+        if not isinstance(start_positions, torch.Tensor):
+            torch.testing.assert_close(grad_fused, grad_unfused)
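For readers skimming the diff, a minimal standalone call of the new apply_fused_qkv_rotary_pos_emb entry point exercised above might look like the following sketch; the argument names, the 4:1:1 query/key/value head split, and the output-shape comment mirror the test's unfused reference, while the concrete sizes are illustrative:

import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_fused_qkv_rotary_pos_emb,
)

seq_len, batch, heads, head_dim = 2048, 2, 64, 128
# Combined QKV projection output: per head, [4 * head_dim | head_dim | head_dim] for Q/K/V.
qkv = torch.randn(seq_len, batch, heads, head_dim * 6, device="cuda", requires_grad=True)

emb = RotaryPositionEmbedding(head_dim, 1.0)(seq_len)
q, k, v = apply_fused_qkv_rotary_pos_emb(
    qkv,
    emb,  # rotary embedding applied to the query heads
    emb,  # rotary embedding applied to the key heads
    tensor_format="sbhd",
    start_positions=None,
    interleaved=False,
    cp_size=1,
    cp_rank=0,
    qkv_split_arg_list=[head_dim * 4, head_dim, head_dim],
)
# q: (seq_len, batch, heads * 4, head_dim); k and v: (seq_len, batch, heads, head_dim)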

0 commit comments