|
1 | 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 | 2 | #
|
3 | 3 | # See LICENSE for license information.
|
4 |
| -from typing import Callable, Tuple, Union |
| 4 | +from typing import Callable, Tuple, Union, List |
5 | 5 | import math
|
6 | 6 | import torch
|
7 | 7 | import pytest
|
8 | 8 | from transformer_engine.pytorch.attention.rope import (
|
9 | 9 | RotaryPositionEmbedding,
|
10 | 10 | apply_rotary_pos_emb,
|
| 11 | + apply_fused_qkv_rotary_pos_emb, |
11 | 12 | )
|
12 | 13 |
|
13 | 14 |
|
14 | 15 | # Gradient is a broadcasted scalar
|
15 |
| -def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: |
16 |
| - return output.sum() * 2 |
| 16 | +def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: |
| 17 | + if isinstance(output, List): |
| 18 | + return sum(t.sum() * 2 for t in output) |
| 19 | + else: |
| 20 | + return output.sum() * 2 |
17 | 21 |
|
18 | 22 |
|
19 | 23 | # Gradient is a full tensor
|
20 |
| -def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: |
21 |
| - t = torch.ones_like(output) |
22 |
| - return torch.sum(output * t) |
| 24 | +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: |
| 25 | + if isinstance(output, List): |
| 26 | + return sum(torch.sum(t * torch.ones_like(t)) for t in output) |
| 27 | + else: |
| 28 | + t = torch.ones_like(output) |
| 29 | + return torch.sum(output * t) |
23 | 30 |
|
24 | 31 |
|
25 | 32 | @pytest.mark.parametrize("start_positions", [True, False])
|
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
|
238 | 245 | torch.testing.assert_close(grad_fused, grad_unfused)
|
239 | 246 |
|
240 | 247 | assert output_fused.is_contiguous()
|
| 248 | + |
| 249 | + |
| 250 | +@pytest.mark.parametrize("start_positions", [True, False]) |
| 251 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) |
| 252 | +@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) |
| 253 | +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) |
| 254 | +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) |
| 255 | +@pytest.mark.parametrize("margin", [0, 10]) |
| 256 | +@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) |
| 257 | +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) |
| 258 | +@pytest.mark.parametrize("cp_size", [1, 2]) |
| 259 | +@pytest.mark.parametrize("interleaved", [True, False]) |
| 260 | +def test_fused_qkv_rope( |
| 261 | + dtype: torch.dtype, |
| 262 | + seq_length: int, |
| 263 | + hidden_size: int, |
| 264 | + rotary_percent: float, |
| 265 | + margin: int, |
| 266 | + tensor_format: str, |
| 267 | + loss_func: Callable, |
| 268 | + cp_size: int, |
| 269 | + interleaved: bool, |
| 270 | + start_positions: bool, |
| 271 | +) -> None: |
| 272 | + if margin == 0 and start_positions == True: |
| 273 | + # This makes sure that the `start_positions` offsets being applied |
| 274 | + # are with the maximum length of the rope embeddings. |
| 275 | + pytest.skip("Skipping test with margin=0 and start_positions=True") |
| 276 | + |
| 277 | + if start_positions == True and cp_size > 1: |
| 278 | + # `start_positions` is only supported for `cp_size=1` and inference. |
| 279 | + pytest.skip("Skipping test with cp_size>1 and start_positions=True") |
| 280 | + |
| 281 | + if seq_length - margin < 0: |
| 282 | + pytest.skip("Skipping test with seq_length - margin < 0") |
| 283 | + |
| 284 | + device = torch.device("cuda:0") |
| 285 | + batch_size, head_num = 2, 64 |
| 286 | + |
| 287 | + t = torch.rand( |
| 288 | + (seq_length - margin, batch_size, head_num, hidden_size * 6), |
| 289 | + dtype=dtype, |
| 290 | + device=device, |
| 291 | + ) |
| 292 | + |
| 293 | + # Get arbitrary offsets to be used with RoPE for all the sequences |
| 294 | + start_positions = ( |
| 295 | + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) |
| 296 | + if start_positions |
| 297 | + else None |
| 298 | + ) |
| 299 | + |
| 300 | + if tensor_format == "bshd": |
| 301 | + t = t.transpose(0, 1).contiguous() |
| 302 | + t.requires_grad = True |
| 303 | + |
| 304 | + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) |
| 305 | + emb_q = rotary_pos_emb_q(seq_length * cp_size) |
| 306 | + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) |
| 307 | + emb_k = rotary_pos_emb_k(seq_length * cp_size) |
| 308 | + |
| 309 | + for cp_rank in range(cp_size): |
| 310 | + # unfused |
| 311 | + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 |
| 312 | + # for more accurate comparison |
| 313 | + |
| 314 | + t_clone = t.clone() |
| 315 | + (query, key, value) = torch.split( |
| 316 | + t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3 |
| 317 | + ) |
| 318 | + query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) |
| 319 | + |
| 320 | + query_unfused = apply_rotary_pos_emb( |
| 321 | + query, |
| 322 | + emb_q, |
| 323 | + tensor_format=tensor_format, |
| 324 | + start_positions=start_positions, |
| 325 | + interleaved=interleaved, |
| 326 | + fused=True, |
| 327 | + cp_size=cp_size, |
| 328 | + cp_rank=cp_rank, |
| 329 | + ).to(dtype) |
| 330 | + |
| 331 | + key_unfused = apply_rotary_pos_emb( |
| 332 | + key, |
| 333 | + emb_k, |
| 334 | + tensor_format=tensor_format, |
| 335 | + start_positions=start_positions, |
| 336 | + interleaved=interleaved, |
| 337 | + fused=True, |
| 338 | + cp_size=cp_size, |
| 339 | + cp_rank=cp_rank, |
| 340 | + ).to(dtype) |
| 341 | + |
| 342 | + value_unfused = value |
| 343 | + loss_unfused = loss_func([query_unfused, key_unfused, value_unfused]) |
| 344 | + |
| 345 | + if not isinstance(start_positions, torch.Tensor): |
| 346 | + loss_unfused.backward() |
| 347 | + grad_unfused = t.grad.detach().clone() |
| 348 | + |
| 349 | + t.grad = None |
| 350 | + |
| 351 | + # fused |
| 352 | + query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb( |
| 353 | + t, |
| 354 | + emb_q, |
| 355 | + emb_k, |
| 356 | + tensor_format=tensor_format, |
| 357 | + start_positions=start_positions, |
| 358 | + interleaved=interleaved, |
| 359 | + cp_size=cp_size, |
| 360 | + cp_rank=cp_rank, |
| 361 | + qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size], |
| 362 | + ) |
| 363 | + loss_fused = loss_func([query_fused, key_fused, value_fused]) |
| 364 | + |
| 365 | + if not isinstance(start_positions, torch.Tensor): |
| 366 | + loss_fused.backward() |
| 367 | + grad_fused = t.grad.detach().clone() |
| 368 | + t.grad = None |
| 369 | + |
| 370 | + torch.testing.assert_close(query_fused, query_unfused) |
| 371 | + torch.testing.assert_close(key_fused, key_unfused) |
| 372 | + torch.testing.assert_close(value_fused, value_unfused) |
| 373 | + |
| 374 | + if not isinstance(start_positions, torch.Tensor): |
| 375 | + torch.testing.assert_close(grad_fused, grad_unfused) |
0 commit comments