Skip to content

Commit 6c9d4c3

Browse files
spectrometerHBHShiboXing
authored and committed
[Runtime] CutensorMap support (apache#18097)
This PR introduces cuTensorMap support in the runtime module. It enables calling kernels whose arguments are cuTensorMap; these arguments are passed as a handle (address) and associated with arg_extra_tags that indicate the argument is a tensor map. The TensorMap is allocated on the stack with a runtime API.
1 parent ee41813 commit 6c9d4c3

File tree

20 files changed

+396
-23
lines changed

20 files changed

+396
-23
lines changed

include/tvm/ir/type.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,5 +330,34 @@ class FuncType : public Type {
330330
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
331331
};
332332

333+
/*!
 * \brief The type of tensor map.
 *
 * The node carries no payload other than the source span, so structural
 * equality and hashing reduce to the span alone.
 * \sa TensorMapType
 */
class TensorMapTypeNode : public TypeNode {
 public:
  void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }

  bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const {
    return equal(span, other->span);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); }

  static constexpr const char* _type_key = "TensorMapType";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode);
};
350+
351+
/*!
 * \brief Managed reference to TensorMapTypeNode.
 * \sa TensorMapTypeNode
 */
class TensorMapType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param span The source span of the type, used for diagnostics.
   */
  TVM_DLL TensorMapType(Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, TensorMapTypeNode);
};
361+
333362
} // namespace tvm
334363
#endif // TVM_IR_TYPE_H_

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
452452
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
453453
}
454454

455+
/*!
 * \brief Create an unnamed handle variable annotated as a pointer to a tensor map.
 * \return A new tir.Var with type annotation PointerType(TensorMapType()).
 */
inline Var TensormapHandle() {
  auto annotation = PointerType(TensorMapType());
  return tvm::tir::Var("", annotation);
}
456+
455457
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
456458
inline PrimExpr FuncName(Optional<PrimExpr> expr = std::nullopt, bool is_size_var = false) { \
457459
DataType dtype = DType; \

python/tvm/ir/type.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,19 @@ def __init__(self, arg_types, ret_type):
107107
arg_types,
108108
ret_type,
109109
)
110+
111+
112+
@tvm.ffi.register_object("TensorMapType")
class TensorMapType(Type):
    """TensorMapType used in the low-level TIR.

    Parameters
    ----------
    span : tvm.ir.Span, optional
        The span information attached to the type.
    """

    def __init__(self, span=None):
        # Build the underlying C++ TensorMapType node through the FFI.
        constructor = _ffi_api.TensorMapType  # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor, span)

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,8 @@ def handle(
15661566
res : PrimExpr
15671567
The new tir.Var with type handle or casted expression with type handle.
15681568
"""
1569+
if dtype == "tensormap":
1570+
return _ffi_api.TensormapHandle() # type: ignore[attr-defined] # pylint: disable=no-member
15691571
is_unknown_type = dtype is None
15701572
if dtype is None:
15711573
dtype = "void"

src/ir/type.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,16 @@ TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
8989
return TupleType(fields);
9090
});
9191

92+
TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) {
93+
return TensorMapType(span);
94+
});
95+
96+
TensorMapType::TensorMapType(Span span) {
97+
ObjectPtr<TensorMapTypeNode> n = make_object<TensorMapTypeNode>();
98+
n->span = std::move(span);
99+
data_ = std::move(n);
100+
}
101+
102+
TVM_REGISTER_NODE_TYPE(TensorMapTypeNode);
103+
92104
} // namespace tvm

src/runtime/cuda/cuda_device_api.cc

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {
357357

358358
TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);
359359

360+
/**
361+
* \brief FFI wrapper for cuTensorMapEncodeTiled.
362+
*
363+
* This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be
364+
* called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API
365+
* function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap).
366+
*
367+
* \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized.
368+
* \param tensor_dtype (DataType): The TVM data type of the tensor.
369+
* \param tensor_rank (int): The rank (number of dimensions) of the tensor.
370+
* \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory.
371+
* \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions.
372+
* \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor
373+
* strides. The stride for the innermost dimension is not provided as it's assumed to be contiguous.
374+
* \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box)
375+
* in shared memory.
376+
* \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile (box)
377+
* in shared memory.
378+
* \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum.
379+
* \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum.
380+
* \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum.
381+
* \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum.
382+
*/
383+
TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
384+
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
385+
CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments";
386+
size_t arg_cnt = 0;
387+
CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
388+
runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
389+
uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
390+
void* tensor_ptr = static_cast<void*>(args[arg_cnt++].cast<void*>());
391+
392+
CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
393+
<< "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments"
394+
<< "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
395+
<< "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
396+
<< "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
397+
<< ", l2_promotion_kind, oob_fill_kind";
398+
399+
std::vector<cuuint64_t> global_shape(tensor_rank);
400+
std::vector<cuuint64_t> global_strides(tensor_rank);
401+
std::vector<uint32_t> shared_shape(tensor_rank);
402+
std::vector<uint32_t> shared_strides(tensor_rank);
403+
for (size_t i = 0; i < tensor_rank; ++i) {
404+
global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
405+
}
406+
for (size_t i = 0; i < tensor_rank - 1; ++i) {
407+
global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
408+
CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
409+
}
410+
for (size_t i = 0; i < tensor_rank; ++i) {
411+
shared_shape[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
412+
CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative";
413+
CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256";
414+
}
415+
for (size_t i = 0; i < tensor_rank; ++i) {
416+
shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
417+
}
418+
auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
419+
auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
420+
auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
421+
auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
422+
423+
ICHECK_EQ(tensor_dtype.lanes(), 1)
424+
<< "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
425+
CUtensorMapDataType cu_dtype;
426+
switch (tensor_dtype.code()) {
427+
case DataType::kInt:
428+
// int
429+
switch (tensor_dtype.bits()) {
430+
case 8:
431+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
432+
break;
433+
case 32:
434+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
435+
break;
436+
case 64:
437+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
438+
break;
439+
default:
440+
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
441+
}
442+
break;
443+
case DataType::kUInt:
444+
// unsigned int
445+
switch (tensor_dtype.bits()) {
446+
case 8:
447+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
448+
break;
449+
case 16:
450+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
451+
break;
452+
case 32:
453+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
454+
break;
455+
case 64:
456+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
457+
break;
458+
default:
459+
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
460+
}
461+
break;
462+
case DataType::kFloat:
463+
// float
464+
switch (tensor_dtype.bits()) {
465+
case 16:
466+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
467+
break;
468+
case 32:
469+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
470+
break;
471+
case 64:
472+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
473+
break;
474+
default:
475+
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
476+
}
477+
break;
478+
case DataType::kBFloat:
479+
// bfloat
480+
switch (tensor_dtype.bits()) {
481+
case 16:
482+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
483+
break;
484+
default:
485+
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
486+
}
487+
break;
488+
case DataType::kFloat8_e4m3fn:
489+
// NV float8 e4m3
490+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
491+
break;
492+
case DataType::kFloat8_e5m2:
493+
// NV float8 e5m2
494+
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
495+
break;
496+
default:
497+
LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
498+
}
499+
500+
// sanity checks per cuTensorMapEncodeTiled requirements
501+
// see
502+
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
503+
CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0); // 16-byte alignment
504+
CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0); // 64-byte alignment
505+
CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";
506+
507+
if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
508+
CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
509+
<< "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
510+
} else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
511+
CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
512+
<< "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
513+
} else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
514+
CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
515+
<< "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
516+
"128.";
517+
}
518+
519+
const cuuint64_t* global_shape_ptr = global_shape.data();
520+
const cuuint64_t* global_strides_ptr = global_strides.data();
521+
const uint32_t* shared_shape_ptr = shared_shape.data();
522+
const uint32_t* shared_strides_ptr = shared_strides.data();
523+
524+
CUresult res =
525+
cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr,
526+
global_strides_ptr, shared_shape_ptr, shared_strides_ptr,
527+
interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
528+
const char* errstr;
529+
cuGetErrorString(res, &errstr);
530+
if (res != CUDA_SUCCESS) {
531+
// get error string
532+
const char* error_string = nullptr;
533+
cuGetErrorString(res, &error_string);
534+
std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl;
535+
std::cout << "cu_dtype: " << cu_dtype << "\n";
536+
std::cout << "TMA Desc Addr: " << tensor_map << "\n";
537+
std::cout << "TMA Interleave: " << interleaved_kind << "\n";
538+
std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n";
539+
std::cout << "TMA OOBFill: " << oob_fill_kind << "\n";
540+
std::cout << "SMEM Swizzle: " << swizzle_kind << "\n";
541+
std::cout << "tensor rank: " << tensor_rank << "\n";
542+
std::cout << "global prob shape: ";
543+
for (size_t i = 0; i < tensor_rank; i++) {
544+
std::cout << global_shape[i] << " ";
545+
}
546+
std::cout << "\n";
547+
std::cout << "global prob stride: ";
548+
for (size_t i = 0; i < tensor_rank; i++) {
549+
std::cout << global_strides[i] << " ";
550+
}
551+
std::cout << "\n";
552+
std::cout << "smem box shape: ";
553+
for (size_t i = 0; i < tensor_rank; i++) {
554+
std::cout << shared_shape[i] << " ";
555+
}
556+
std::cout << "\n";
557+
std::cout << "smem box stride: ";
558+
for (size_t i = 0; i < tensor_rank; i++) {
559+
std::cout << shared_strides[i] << " ";
560+
}
561+
std::cout << "\n";
562+
CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr;
563+
}
564+
});
565+
360566
} // namespace runtime
361567
} // namespace tvm

src/runtime/cuda/cuda_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name,
266266
const FunctionInfo& info = it->second;
267267
CUDAWrappedFunc f;
268268
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
269-
return PackFuncVoidAddr(f, info.arg_types);
269+
return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags);
270270
}
271271

272272
Module CUDAModuleCreate(std::string data, std::string fmt,

src/runtime/file_utils.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
4545
writer->WriteObjectKeyValue("name", name);
4646
writer->WriteObjectKeyValue("arg_types", sarg_types);
4747
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
48+
std::vector<int> iarg_extra_tags(arg_extra_tags.size());
49+
for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
50+
iarg_extra_tags[i] = static_cast<int>(arg_extra_tags[i]);
51+
}
52+
writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags);
4853
writer->EndObject();
4954
}
5055

@@ -56,6 +61,12 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
5661
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
5762
helper.DeclareOptionalField("thread_axis_tags",
5863
&launch_param_tags); // for backward compatibility
64+
std::vector<int> iarg_extra_tags;
65+
helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags);
66+
arg_extra_tags.resize(iarg_extra_tags.size());
67+
for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
68+
arg_extra_tags[i] = static_cast<ArgExtraTags>(iarg_extra_tags[i]);
69+
}
5970
helper.ReadAllFields(reader);
6071
arg_types.resize(sarg_types.size());
6172
for (size_t i = 0; i < arg_types.size(); ++i) {
// Binary serialization of FunctionInfo. The field order must stay in sync
// with FunctionInfo::Load(dmlc::Stream*).
void FunctionInfo::Save(dmlc::Stream* writer) const {
  writer->Write(name);
  writer->Write(arg_types);
  writer->Write(launch_param_tags);
  // NOTE(review): arg_extra_tags is written unconditionally, which extends the
  // binary format; confirm readers of previously-serialized modules are
  // handled (a format/version bump or tolerant reader) elsewhere.
  writer->Write(arg_extra_tags);
}
7183

7284
// Binary deserialization of FunctionInfo; mirrors the field order of
// FunctionInfo::Save(dmlc::Stream*).
// \return true if every field was read successfully, false otherwise.
bool FunctionInfo::Load(dmlc::Stream* reader) {
  // Short-circuit evaluation preserves the original field-by-field read order
  // and stops at the first failure.
  return reader->Read(&name) && reader->Read(&arg_types) && reader->Read(&launch_param_tags) &&
         reader->Read(&arg_extra_tags);
}
7891

src/runtime/meta_data.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ struct FunctionInfo {
5959
std::vector<DLDataType> arg_types;
6060
std::vector<std::string> launch_param_tags;
6161

62+
enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
63+
std::vector<ArgExtraTags> arg_extra_tags;
64+
6265
void Save(dmlc::JSONWriter* writer) const;
6366
void Load(dmlc::JSONReader* reader);
6467
void Save(dmlc::Stream* writer) const;

0 commit comments

Comments
 (0)