Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions python/tvm/script/ir_builder/tir/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.
"""Triton kernel integration with TIR"""

from typing import Tuple, List, Union, Any, Dict
from typing import Any, Dict, List, Tuple, Union

import triton
from packaging import version
from triton.runtime.jit import type_canonicalisation_dict

from tvm import tir
from tvm.topi.utils import get_const_int
from tvm.runtime import Module
from tvm.topi.utils import get_const_int

from .external_kernel import BaseKernel


Expand Down Expand Up @@ -70,7 +73,12 @@ def compile_to_device_module(
: len(grid)
]
launch_args = [num_warps * 32] + list(grid)
kernel_arg_types = [arg.dtype for arg in kernel_args]
if version.parse(triton.__version__) >= version.parse("3.3.0"):
kernel_arg_types = [
arg.dtype if not isinstance(arg, int) else "int64" for arg in kernel_args
]
else:
kernel_arg_types = [arg.dtype for arg in kernel_args]
if triton_kernel.metadata.shared > 0:
# Add shared memory size to the launch arguments
launch_param_tags.append("tir.use_dyn_shared_memory")
Expand Down Expand Up @@ -98,6 +106,9 @@ def _generate_triton_kernel(
for i, arg in enumerate(args):
if kernel_params[i].is_constexpr:
constants[kernel_params[i].name] = get_const_int(arg)
if version.parse(triton.__version__) >= version.parse("3.3.0"):
signature[kernel_params[i].name] = "constexpr"
kernel_args.append(arg)
continue
if arg.dtype == "handle":
assert isinstance(arg, tir.Var)
Expand All @@ -110,6 +121,10 @@ def _generate_triton_kernel(

# TODO: Support default argument in the kernel
# TODO: Add specialization for aligned buffer pointers
source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature)
if version.parse(triton.__version__) >= version.parse("3.3.0"):
kwargs = {"constexprs": constants}
else:
kwargs = {"constants": constants}
source = triton.compiler.ASTSource(fn=func, signature=signature, **kwargs)
compiled = triton.compiler.compile(source, options=kwargs)
return compiled, kernel_args
Loading