Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,5 +312,34 @@ class FuncType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};

/*!
* \brief The type of tensor map.
* \sa TensorMapType
*/
class TensorMapTypeNode : public TypeNode {
public:
void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not have to be now, see #18098 to transition to new reflection, i can do that after erge as well


bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const {
return equal(span, other->span);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); }

static constexpr const char* _type_key = "TensorMapType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode);
};

/*!
 * \brief Managed reference to TensorMapTypeNode.
 * \sa TensorMapTypeNode
 */
class TensorMapType : public Type {
public:
/*!
 * \brief Constructor.
 * \param span The span information, for debugging purposes (defaults to an empty span).
 */
TVM_DLL TensorMapType(Span span = Span());

// No default-constructed (null) references: a TensorMapType always wraps a node.
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, TensorMapTypeNode);
};

} // namespace tvm
#endif // TVM_IR_TYPE_H_
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
}

/*!
 * \brief Create an unnamed tir::Var whose type annotation is a pointer to a tensor map.
 * \return The new variable, annotated with PointerType(TensorMapType()).
 */
inline Var TensormapHandle() {
  auto type_annotation = PointerType(TensorMapType());
  return tvm::tir::Var("", type_annotation);
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = std::nullopt, bool is_size_var = false) { \
DataType dtype = DType; \
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,19 @@ def __init__(self, arg_types, ret_type):
arg_types,
ret_type,
)


@tvm.ffi.register_object("TensorMapType")
class TensorMapType(Type):
    """TensorMapType used in the low-level TIR.

    Parameters
    ----------
    span : tvm.ir.Span
        The span information.
    """

    def __init__(self, span=None):
        # Construct the underlying TensorMapTypeNode through the FFI.
        self.__init_handle_by_constructor__(
            _ffi_api.TensorMapType,  # pylint: disable=no-member
            span,
        )
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,8 @@ def handle(
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
if dtype == "tensormap":
return _ffi_api.TensormapHandle() # type: ignore[attr-defined] # pylint: disable=no-member
is_unknown_type = dtype is None
if dtype is None:
dtype = "void"
Expand Down
12 changes: 12 additions & 0 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,16 @@ TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});

// Construct a TensorMapType node; the only stored field is the diagnostic span.
TensorMapType::TensorMapType(Span span) {
  auto node = make_object<TensorMapTypeNode>();
  node->span = std::move(span);
  data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(TensorMapTypeNode);

// Expose the constructor to the FFI (used from python/tvm/ir/type.py).
TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) {
  return TensorMapType(span);
});

} // namespace tvm
206 changes: 206 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {

TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);

/**
 * \brief FFI wrapper for cuTensorMapEncodeTiled.
 *
 * This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be
 * called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API
 * function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap).
 *
 * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized.
 * \param tensor_dtype (DataType): The TVM data type of the tensor.
 * \param tensor_rank (int): The rank (number of dimensions) of the tensor.
 * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory.
 * \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions.
 * \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor
 * strides. The stride for the innermost dimension is not provided as it's assumed to be contiguous.
 * \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box)
 * in shared memory.
 * \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile (box)
 * in shared memory.
 * \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum.
 * \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum.
 * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum.
 * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum.
 */
TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
      CHECK_GE(args.size(), 4) << "cuTensorMapEncodeTiled expects at least 4 arguments";
      size_t arg_cnt = 0;
      CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
      runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
      uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
      void* tensor_ptr = args[arg_cnt++].cast<void*>();

      // Guard the rank first: the stride loop below iterates `tensor_rank - 1`
      // times, which would underflow for an unsigned rank of 0.
      CHECK_GE(tensor_rank, 1) << "cuTensorMapEncodeTiled requires tensor_rank >= 1";
      CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";

      CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
          << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments: "
          << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
          << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
          << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
          << ", l2_promotion_kind, oob_fill_kind";

      // Only `tensor_rank - 1` global strides are passed: the innermost
      // dimension is assumed contiguous by the driver API.
      std::vector<cuuint64_t> global_shape(tensor_rank);
      std::vector<cuuint64_t> global_strides(tensor_rank - 1);
      std::vector<uint32_t> shared_shape(tensor_rank);
      std::vector<uint32_t> shared_strides(tensor_rank);
      for (size_t i = 0; i < tensor_rank; ++i) {
        global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
      }
      for (size_t i = 0; i + 1 < tensor_rank; ++i) {
        global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
        CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
      }
      for (size_t i = 0; i < tensor_rank; ++i) {
        // Validate on the signed value BEFORE the unsigned cast; checking
        // `shared_shape[i] >= 0` after the cast would be a tautology.
        int32_t box_dim = args[arg_cnt++].cast<int32_t>();
        CHECK_GE(box_dim, 0) << "boxDim must be non-negative";
        CHECK_LE(box_dim, 256) << "boxDim must be less than or equal to 256";
        shared_shape[i] = static_cast<uint32_t>(box_dim);
      }
      for (size_t i = 0; i < tensor_rank; ++i) {
        shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
      }
      auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
      auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
      auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
      auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());

      ICHECK_EQ(tensor_dtype.lanes(), 1)
          << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
      // Map the TVM dtype onto the corresponding CUtensorMapDataType enum value.
      CUtensorMapDataType cu_dtype;
      switch (tensor_dtype.code()) {
        case DataType::kInt:
          // int; note int8 has no dedicated enum value and is encoded as UINT8.
          switch (tensor_dtype.bits()) {
            case 8:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
              break;
            case 32:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
              break;
            case 64:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
              break;
            default:
              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
          }
          break;
        case DataType::kUInt:
          // unsigned int
          switch (tensor_dtype.bits()) {
            case 8:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
              break;
            case 16:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
              break;
            case 32:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
              break;
            case 64:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
              break;
            default:
              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
          }
          break;
        case DataType::kFloat:
          // float
          switch (tensor_dtype.bits()) {
            case 16:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
              break;
            case 32:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
              break;
            case 64:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
              break;
            default:
              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
          }
          break;
        case DataType::kBFloat:
          // bfloat
          switch (tensor_dtype.bits()) {
            case 16:
              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
              break;
            default:
              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
          }
          break;
        case DataType::kFloat8_e4m3fn:
          // NV float8 e4m3: carried as raw bytes (UINT8) through the tensor map.
          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
          break;
        case DataType::kFloat8_e5m2:
          // NV float8 e5m2: carried as raw bytes (UINT8) through the tensor map.
          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
          break;
        default:
          LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
      }

      // sanity checks per cuTensorMapEncodeTiled requirements
      // see
      // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0);   // 16-byte alignment
      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0);  // 64-byte alignment

      if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
            << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
            << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
            << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
               "128.";
      }

      CUresult res =
          cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape.data(),
                                 global_strides.data(), shared_shape.data(), shared_strides.data(),
                                 interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
      if (res != CUDA_SUCCESS) {
        // Fetch the error string only on the failure path (it was previously
        // fetched redundantly before the status check as well).
        const char* error_string = nullptr;
        cuGetErrorString(res, &error_string);
        if (error_string == nullptr) error_string = "unknown error";
        // Dump the full encoding request to ease debugging, on stderr so it is
        // not interleaved with regular stdout output.
        std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << "\n";
        std::cerr << "cu_dtype: " << cu_dtype << "\n";
        std::cerr << "TMA Desc Addr: " << tensor_map << "\n";
        std::cerr << "TMA Interleave: " << interleaved_kind << "\n";
        std::cerr << "TMA L2Promotion: " << l2_promotion_kind << "\n";
        std::cerr << "TMA OOBFill: " << oob_fill_kind << "\n";
        std::cerr << "SMEM Swizzle: " << swizzle_kind << "\n";
        std::cerr << "tensor rank: " << tensor_rank << "\n";
        std::cerr << "global prob shape: ";
        for (size_t i = 0; i < tensor_rank; i++) {
          std::cerr << global_shape[i] << " ";
        }
        std::cerr << "\n";
        // Only tensor_rank - 1 strides exist; printing tensor_rank of them
        // would read past the provided values.
        std::cerr << "global prob stride: ";
        for (size_t i = 0; i + 1 < tensor_rank; i++) {
          std::cerr << global_strides[i] << " ";
        }
        std::cerr << "\n";
        std::cerr << "smem box shape: ";
        for (size_t i = 0; i < tensor_rank; i++) {
          std::cerr << shared_shape[i] << " ";
        }
        std::cerr << "\n";
        std::cerr << "smem box stride: ";
        for (size_t i = 0; i < tensor_rank; i++) {
          std::cerr << shared_strides[i] << " ";
        }
        std::cerr << "\n";
        LOG(FATAL) << "Error in cuTensorMapEncodeTiled: " << error_string;
      }
    });

} // namespace runtime
} // namespace tvm
2 changes: 1 addition & 1 deletion src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name,
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags);
}

Module CUDAModuleCreate(std::string data, std::string fmt,
Expand Down
13 changes: 13 additions & 0 deletions src/runtime/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
std::vector<int> iarg_extra_tags(arg_extra_tags.size());
for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
iarg_extra_tags[i] = static_cast<int>(arg_extra_tags[i]);
}
writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags);
writer->EndObject();
}

Expand All @@ -56,6 +61,12 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
helper.DeclareOptionalField("thread_axis_tags",
&launch_param_tags); // for backward compatibility
std::vector<int> iarg_extra_tags;
helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags);
arg_extra_tags.resize(iarg_extra_tags.size());
for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
arg_extra_tags[i] = static_cast<ArgExtraTags>(iarg_extra_tags[i]);
}
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
Expand All @@ -67,12 +78,14 @@ void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(launch_param_tags);
writer->Write(arg_extra_tags);
}

// Binary deserialization of FunctionInfo; reads fields in the order
// FunctionInfo::Save(dmlc::Stream*) wrote them, stopping at the first failure.
// NOTE(review): arg_extra_tags is read unconditionally, so artifacts
// serialized before this field existed will fail to load — confirm intended.
bool FunctionInfo::Load(dmlc::Stream* reader) {
  bool ok = reader->Read(&name);
  ok = ok && reader->Read(&arg_types);
  ok = ok && reader->Read(&launch_param_tags);
  ok = ok && reader->Read(&arg_extra_tags);
  return ok;
}

Expand Down
3 changes: 3 additions & 0 deletions src/runtime/meta_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ struct FunctionInfo {
std::vector<DLDataType> arg_types;
std::vector<std::string> launch_param_tags;

enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
std::vector<ArgExtraTags> arg_extra_tags;

void Save(dmlc::JSONWriter* writer) const;
void Load(dmlc::JSONReader* reader);
void Save(dmlc::Stream* writer) const;
Expand Down
Loading
Loading