
Commit 14b732e

Lower the horizontal batched reduce of a matrix with optimized inline VISA.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 92cff48 commit 14b732e
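
In brief (from the diffs below): for a row-wise (axis = 1) reduction, the generic lowering reduces each per-row accumulator separately; this change adds an opt-in path, gated by the TRITON_INTEL_ENABLE_SIMD_REDUCE environment variable, in which the target lowers the whole batch of accumulators at once through a new warpBatchReduce hook, with the Intel backend emitting hand-written inline VISA built by the new simdReduceAsm helper declared below.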

File tree

11 files changed: +478 -1 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 6 additions & 0 deletions
@@ -66,6 +66,12 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;
 
+  virtual bool
+  warpBatchReduce(RewriterBase &rewriter, Location loc,
+                  std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                  triton::ReduceOp op, unsigned numLaneToReduce,
+                  unsigned interleave) const = 0;
+
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of
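
The contract, inferred from the call site added in ReduceOpToLLVM.cpp below: an implementation either lowers every accumulator batch in `acc` and returns true, or returns false to fall back to the generic shuffle-based reduction. A minimal hypothetical override that only short-circuits the degenerate case might look like:

// Hypothetical sketch, not the Intel implementation: decline everything
// except the trivial case where no cross-lane reduction is needed.
bool MyTargetInfo::warpBatchReduce(
    RewriterBase &rewriter, Location loc,
    std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
    triton::ReduceOp op, unsigned numLaneToReduce,
    unsigned interleave) const {
  if (numLaneToReduce == 1)
    return true; // each lane already holds its final value
  return false;  // fall back to the generic shuffle-based path
}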

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_FAST_MATH",
     "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
+    "TRITON_INTEL_ENABLE_SIMD_REDUCE",
     // clang-format on
 };
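
Adding the new variable to CACHE_INVALIDATING_ENV_VARS is what makes the flag safe to toggle: as the set's name suggests, these variables participate in compilation-cache invalidation, so changing TRITON_INTEL_ENABLE_SIMD_REDUCE triggers recompilation instead of reusing kernels cached under the other setting.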

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import os
import re
import numpy as np
from numpy.random import RandomState
import pytest
import torch
import pathlib

import triton
from triton._internal_testing import numpy_random, to_numpy

MIN_GROUP_SIZE = torch.xpu.get_device_capability()['sub_group_sizes'][0]

os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"


class DpasLayout:

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        return f"#ttig.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"


class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
                 cta_split_num=[1, 1], cta_order=[0, 1]):
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"


layouts = [
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=16,
               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
    BlockedLayout([1, 1], [1, 16], [2, 2], [1, 0])
]

if MIN_GROUP_SIZE == 16:
    # Add threads_per_warp=32 cases.
    layouts += [
        DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
                   warps_per_cta=[2, 2], rep_cluster=[1, 1]),
        BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0])
    ]


def warps_per_cta(layout, shape):
    return layout.warps_per_cta


GPU_DIALECT = "ttg"


@pytest.mark.parametrize("M, N", [[8, 64], [16, 64], [32, 64], [64, 64], [128, 64]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dtype_str", ["float32", "float16"])
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
def test_horizontal_simd_reduce(M, N, src_layout, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
    ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
    arith_op = {
        "max": {"int32": "arith.maxsi", "float32": "arith.maxnumf", "float16": "arith.maxnumf"},  #
        "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
    }[reduce_op][dtype_str]
    numpy_op = {"max": np.max, "sum": np.sum}[reduce_op]
    rdims_1d = f"{M}"
    rdims_2d = f"{M}x1"
    store_range = "%1"
    warps = src_layout.warps_per_cta
    threads_per_warp = int(np.prod(src_layout.threads_per_warp))
    num_warps = int(np.prod(warps))
    blocked = BlockedLayout([1, 1], [16, threads_per_warp // 16], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
    one_d_layout = BlockedLayout([1], [threads_per_warp], [num_warps], [0], [1], [1], [0])

    ir = f"""
#blocked = {blocked}
#src = {src_layout}
#one_d_layout = {one_d_layout}
module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32, "ttig.min_sg_size" = {MIN_GROUP_SIZE} }} {{
  tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
    %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
    %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked>
    %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
    %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
    %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
    %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
    %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
    %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
    %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
    %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
    %13 = "tt.reduce"(%12) ({{
    ^bb0(%arg3: {ty}, %arg4: {ty}):
      %17 = {arith_op} %arg3, %arg4 : {ty}
      tt.reduce.return %17 : {ty}
    }}) {{axis = 1 : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
    %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
    %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
    %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
    %17 = tt.expand_dims %16 {{axis = 1 : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked>
    tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
    tt.return
  }}
}}
"""

    temp_file = tmp_path / "test_reduce_layouts.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
    z_shape = (M, 1)
    z = np.zeros(z_shape).astype(dtype_str)

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

    kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri)
    z_ref = numpy_op(x, axis=1, keepdims=True)

    if dtype_str == 'float16':
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)

    llir = kernel.asm['llir']
    assert re.search(r'call .* asm', llir), 'no inline visa in llir'  # inline visa is used
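
The test validates two things: numerical agreement with the NumPy reference (with a looser absolute tolerance for float16), and that the SIMD path was actually exercised, since the closing regex only matches when the generated LLVM IR contains an inline-assembly call, which the generic shuffle-based lowering would not produce.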

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 7 additions & 0 deletions
@@ -56,6 +56,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override {
+    return false;
+  };
+
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,

third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h

Lines changed: 18 additions & 0 deletions
@@ -315,6 +315,24 @@ struct XeInstrExecution {
   bool onlyAttachMLIRArgs{};
 };
 
+enum XeArch {
+  Xe = 0,
+  Xe2 = 1,
+  Xe3 = 2,
+};
+
+struct XeVISAInstr : public XeInstrBase<XeVISAInstr> {
+  using XeInstrBase<XeVISAInstr>::XeInstrBase;
+
+  static std::optional<std::string> getTypeName(Type scalarTy);
+  static unsigned getGRFSizeInBytes(XeArch arch);
+  static unsigned getExecMaskLaneNum(XeArch arch);
+};
+
+std::string simdReduceAsm(std::string binOp, unsigned warpSize,
+                          unsigned numLaneToReduce, unsigned accSize,
+                          Type elemTy, XeArch arch);
+
 } // namespace triton
 } // namespace mlir
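
A hypothetical usage sketch of the new helper (argument values are illustrative; the parameter semantics are inferred from the declaration above, and the real emission logic lives in the corresponding .cpp): build the inline VISA string for a batched f32 add-reduction over a 16-lane sub-group on Xe2, then attach it to an inline-asm call:

// Sketch only: "add" as the binOp spelling and accSize = 8 are assumptions.
mlir::Type elemTy = rewriter.getF32Type(); // `rewriter` assumed in scope
std::string visa = simdReduceAsm(/*binOp=*/"add", /*warpSize=*/16,
                                 /*numLaneToReduce=*/16, /*accSize=*/8,
                                 elemTy, XeArch::Xe2);
// The string would then be wired into an LLVM inline-asm op when
// lowering the batched reduction.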

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 13 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include <triton/Tools/Sys/GetEnv.hpp>
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -176,6 +177,18 @@ struct ReduceOpConversion
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned threadOffsetOnReductionAxis =
         helper.getThreadOffsetOnReductionAxis();
+
+    bool simdReduce =
+        triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_SIMD_REDUCE");
+
+    if (simdReduce) {
+      auto ret = targetInfo.warpBatchReduce(rewriter, op.getLoc(), accs, op,
+                                            sizeIntraWarps,
+                                            threadOffsetOnReductionAxis);
+      if (ret)
+        return;
+    }
+
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = accs[key];
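
Design note: the environment flag only requests the SIMD path. warpBatchReduce may still decline by returning false (the AMD stub above always does, and presumably the Intel implementation does for unsupported shapes or element types), in which case execution falls through to the existing per-accumulator loop, so enabling the flag never changes which reductions are computed, only how they are lowered.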
