diff --git a/python/tvm/script/ir_builder/tir/triton.py b/python/tvm/script/ir_builder/tir/triton.py index 2d37d93a6dd8..4f0f2cd5202d 100644 --- a/python/tvm/script/ir_builder/tir/triton.py +++ b/python/tvm/script/ir_builder/tir/triton.py @@ -16,13 +16,16 @@ # under the License. """Triton kernel integration with TIR""" -from typing import Tuple, List, Union, Any, Dict +from typing import Any, Dict, List, Tuple, Union import triton +from packaging import version from triton.runtime.jit import type_canonicalisation_dict + from tvm import tir -from tvm.topi.utils import get_const_int from tvm.runtime import Module +from tvm.topi.utils import get_const_int + from .external_kernel import BaseKernel @@ -70,7 +73,12 @@ def compile_to_device_module( : len(grid) ] launch_args = [num_warps * 32] + list(grid) - kernel_arg_types = [arg.dtype for arg in kernel_args] + if version.parse(triton.__version__) >= version.parse("3.3.0"): + kernel_arg_types = [ + arg.dtype if not isinstance(arg, int) else "int64" for arg in kernel_args + ] + else: + kernel_arg_types = [arg.dtype for arg in kernel_args] if triton_kernel.metadata.shared > 0: # Add shared memory size to the launch arguments launch_param_tags.append("tir.use_dyn_shared_memory") @@ -98,6 +106,9 @@ def _generate_triton_kernel( for i, arg in enumerate(args): if kernel_params[i].is_constexpr: constants[kernel_params[i].name] = get_const_int(arg) + if version.parse(triton.__version__) >= version.parse("3.3.0"): + signature[kernel_params[i].name] = "constexpr" + kernel_args.append(arg) continue if arg.dtype == "handle": assert isinstance(arg, tir.Var) @@ -110,6 +121,10 @@ def _generate_triton_kernel( # TODO: Support default argument in the kernel # TODO: Add specialization for aligned buffer pointers - source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature) + if version.parse(triton.__version__) >= version.parse("3.3.0"): + kwargs = {"constexprs": constants} + else: + kwargs = {"constants": constants} + source = triton.compiler.ASTSource(fn=func, signature=signature, **kwargs) compiled = triton.compiler.compile(source, options=kwargs) return compiled, kernel_args