@@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {
 TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);

+/**
+ * \brief FFI wrapper for cuTensorMapEncodeTiled.
+ *
+ * This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be
+ * called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API
+ * function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap).
+ *
+ * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized.
+ * \param tensor_dtype (DataType): The TVM data type of the tensor.
+ * \param tensor_rank (int): The rank (number of dimensions) of the tensor.
+ * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory.
+ * \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions.
+ * \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor
+ *        strides. The stride for the innermost dimension is not provided as it is assumed to be
+ *        contiguous.
+ * \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box)
+ *        in shared memory.
+ * \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile
+ *        (box) in shared memory.
+ * \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum.
+ * \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum.
+ * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum.
+ * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum.
+ */
+TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+      CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments";
+      size_t arg_cnt = 0;
+      CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
+      runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
+      uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      void* tensor_ptr = static_cast<void*>(args[arg_cnt++].cast<void*>());
+
+      CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
+          << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments"
+          << " tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
+          << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
+          << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
+          << ", l2_promotion_kind, oob_fill_kind";
+
+      std::vector<cuuint64_t> global_shape(tensor_rank);
+      std::vector<cuuint64_t> global_strides(tensor_rank);
+      std::vector<uint32_t> shared_shape(tensor_rank);
+      std::vector<uint32_t> shared_strides(tensor_rank);
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+      }
+      for (size_t i = 0; i < tensor_rank - 1; ++i) {
+        global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+        CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_shape[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+        CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative";
+        CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      }
+      auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
+      auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
+      auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
+      auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
+
+      ICHECK_EQ(tensor_dtype.lanes(), 1)
+          << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
+      CUtensorMapDataType cu_dtype;
+      switch (tensor_dtype.code()) {
+        case DataType::kInt:
+          // int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kUInt:
+          // unsigned int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat:
+          // float
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kBFloat:
+          // bfloat
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat8_e4m3fn:
+          // NV float8 e4m3
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        case DataType::kFloat8_e5m2:
+          // NV float8 e5m2
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        default:
+          LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+      }
+
+      // sanity checks per cuTensorMapEncodeTiled requirements
+      // see
+      // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0);    // 16-byte alignment
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0);  // 64-byte alignment
+      CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";
+
+      if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
+            << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
+            << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
+            << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
+               "128.";
+      }
+
+      const cuuint64_t* global_shape_ptr = global_shape.data();
+      const cuuint64_t* global_strides_ptr = global_strides.data();
+      const uint32_t* shared_shape_ptr = shared_shape.data();
+      const uint32_t* shared_strides_ptr = shared_strides.data();
+
+      CUresult res =
+          cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr,
+                                 global_strides_ptr, shared_shape_ptr, shared_strides_ptr,
+                                 interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
+      const char* errstr;
+      cuGetErrorString(res, &errstr);
+      if (res != CUDA_SUCCESS) {
+        // get error string
+        const char* error_string = nullptr;
+        cuGetErrorString(res, &error_string);
+        std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl;
+        std::cout << "cu_dtype: " << cu_dtype << "\n";
+        std::cout << "TMA Desc Addr: " << tensor_map << "\n";
+        std::cout << "TMA Interleave: " << interleaved_kind << "\n";
+        std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n";
+        std::cout << "TMA OOBFill: " << oob_fill_kind << "\n";
+        std::cout << "SMEM Swizzle: " << swizzle_kind << "\n";
+        std::cout << "tensor rank: " << tensor_rank << "\n";
+        std::cout << "global prob shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "global prob stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_strides[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_strides[i] << " ";
+        }
+        std::cout << "\n";
+        CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr;
+      }
+    });
+
 }  // namespace runtime
 }  // namespace tvm
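For context, the sketch below shows one way the registered global might be invoked from the Python side. It is a hedged illustration, not part of this commit: the buffer names, the use of torch for the device allocation, the ctypes trick used to obtain a 64-byte-aligned descriptor, and the literal enum values (0 for the "none" kinds, 3 for CU_TENSOR_MAP_SWIZZLE_128B, assumed to match the CUDA driver headers) are all assumptions. For a rank-2 tensor the wrapper expects 4 + 2*4 + 3 = 15 positional arguments: the four leading handles/scalars, two global dims, one global stride (the innermost stride is implicit), two box dims, two box strides, and the four enum kinds.

# Hypothetical usage sketch (not part of this commit): encode a CUtensorMap for a
# 256x128 row-major float16 matrix read in 64x64 tiles with 128-byte swizzling.
import ctypes
import torch
import tvm

encode = tvm.get_global_func("runtime.cuTensorMapEncodeTiled")

# 128-byte CUtensorMap storage; the wrapper checks 64-byte alignment of this pointer,
# so over-allocate and round the address up manually.
storage = (ctypes.c_byte * (128 + 64))()
desc_addr = (ctypes.addressof(storage) + 63) & ~63
tensor_map = ctypes.c_void_p(desc_addr)

a = torch.empty((256, 128), dtype=torch.float16, device="cuda")  # global tensor
tensor_ptr = ctypes.c_void_p(a.data_ptr())  # 16-byte aligned device address
rank, elem_bytes = 2, 2

encode(
    tensor_map,               # tensor_map handle (64-byte aligned)
    tvm.DataType("float16"),  # tensor_dtype
    rank,                     # tensor_rank
    tensor_ptr,               # tensor_ptr
    128, 256,                 # global_shape, innermost dimension first
    128 * elem_bytes,         # global_strides in bytes (rank - 1 values, multiples of 16)
    64, 64,                   # shared_shape: shared-memory box dims (each <= 256)
    1, 1,                     # shared_strides: element strides inside the box
    0,                        # interleaved_kind  = CU_TENSOR_MAP_INTERLEAVE_NONE
    3,                        # swizzle_kind      = CU_TENSOR_MAP_SWIZZLE_128B
    0,                        # l2_promotion_kind = CU_TENSOR_MAP_L2_PROMOTION_NONE
    0,                        # oob_fill_kind     = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
)

With float16 elements and 128-byte swizzling, the inner box dimension can be at most 128 / 2 = 64 elements, which is why the tile above is 64 wide; the same arithmetic is enforced by the CHECK_LE guards in the wrapper.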