Skip to content

Commit c167d43

Browse files
committed
Add horizontal batched reduce.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 92cff48 commit c167d43

File tree

11 files changed

+451
-1
lines changed

11 files changed

+451
-1
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ class TargetInfoBase {
6666
unsigned numLaneToReduce,
6767
unsigned interleave) const = 0;
6868

69+
virtual bool
70+
warpBatchReduce(RewriterBase &rewriter, Location loc,
71+
std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
72+
triton::ReduceOp op, unsigned numLaneToReduce,
73+
unsigned interleave) const = 0;
74+
6975
virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
7076
// Emits LLVM code with |rewriter| to print a message following the given
7177
// format from the device. |formatStrStart| is the pointer to the start of

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5454
"TRITON_INTEL_FAST_MATH",
5555
"TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
5656
"TRITON_INTEL_REDUCE_TRANSPOSE",
57+
"TRITON_INTEL_ENABLE_SIMD_REDUCE",
5758
// clang-format on
5859
};
5960

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import os
2+
import re
3+
import numpy as np
4+
from numpy.random import RandomState
5+
import pytest
6+
import torch
7+
import pathlib
8+
9+
import triton
10+
from triton._internal_testing import numpy_random, to_numpy
11+
12+
MIN_GROUP_SIZE = torch.xpu.get_device_capability()['sub_group_sizes'][0]
13+
14+
os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"
15+
16+
17+
class DpasLayout:
18+
19+
def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
20+
rep_cluster):
21+
self.repeatCount = repeatCount
22+
self.systolic_depth = systolic_depth
23+
self.execution_size = execution_size
24+
self.ops_per_chan = ops_per_chan
25+
self.threads_per_warp = threads_per_warp
26+
self.warps_per_cta = warps_per_cta
27+
self.rep_cluster = rep_cluster
28+
29+
def __str__(self):
30+
return f"#ttig.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"
31+
32+
33+
class BlockedLayout:
34+
35+
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
36+
cta_split_num=[1, 1], cta_order=[0, 1]):
37+
self.sz_per_thread = size_per_thread
38+
self.threads_per_warp = threads_per_warp
39+
self.warps_per_cta = warps_per_cta
40+
self.order = order
41+
self.ctas_per_cga = ctas_per_cga
42+
self.cta_split_num = cta_split_num
43+
self.cta_order = cta_order
44+
45+
def __str__(self):
46+
return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
47+
48+
49+
layouts = [
50+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
51+
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
52+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
53+
warps_per_cta=[2, 2], rep_cluster=[1, 1]),
54+
BlockedLayout([1, 1], [1, 16], [2, 2], [1, 0])
55+
]
56+
57+
if MIN_GROUP_SIZE == 16:
58+
# Add threads_per_warp=32 cases.
59+
layouts += [BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0])]
60+
61+
62+
def warps_per_cta(layout, shape):
63+
return layout.warps_per_cta
64+
65+
66+
GPU_DIALECT = "ttg"
67+
68+
69+
@pytest.mark.parametrize("M, N", [[8, 64], [16, 64], [32, 64], [64, 64], [128, 64]])
70+
@pytest.mark.parametrize("src_layout", layouts)
71+
@pytest.mark.parametrize("dtype_str", ["float32", "float16"])
72+
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
73+
def test_horizontal_simd_reduce(M, N, src_layout, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
74+
ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
75+
arith_op = {
76+
"max": {"int32": "arith.maxsi", "float32": "arith.maxnumf", "float16": "arith.maxnumf"}, #
77+
"sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
78+
}[reduce_op][dtype_str]
79+
numpy_op = {"max": np.max, "sum": np.sum}[reduce_op]
80+
rdims_1d = f"{M}"
81+
rdims_2d = f"{M}x1"
82+
store_range = "%1"
83+
warps = src_layout.warps_per_cta
84+
threads_per_warp = int(np.prod(src_layout.threads_per_warp))
85+
num_warps = int(np.prod(warps))
86+
blocked = BlockedLayout([1, 1], [16, threads_per_warp // 16], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
87+
one_d_layout = BlockedLayout([1], [threads_per_warp], [num_warps], [0], [1], [1], [0])
88+
89+
ir = f"""
90+
#blocked = {blocked}
91+
#src = {src_layout}
92+
#one_d_layout = {one_d_layout}
93+
module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32, "ttig.min_sg_size" = {MIN_GROUP_SIZE} }} {{
94+
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
95+
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
96+
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
97+
%2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked>
98+
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
99+
%4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
100+
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
101+
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
102+
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
103+
%8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
104+
%9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
105+
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
106+
%11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
107+
%12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
108+
%13 = "tt.reduce"(%12) ({{
109+
^bb0(%arg3: {ty}, %arg4: {ty}):
110+
%17 = {arith_op} %arg3, %arg4 : {ty}
111+
tt.reduce.return %17 : {ty}
112+
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
113+
%14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
114+
%15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
115+
%16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
116+
%17 = tt.expand_dims %16 {{axis = 1 : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked>
117+
tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
118+
tt.return
119+
}}
120+
}}
121+
"""
122+
123+
temp_file = tmp_path / "test_reduce_layouts.ttgir"
124+
temp_file.write_text(ir)
125+
kernel = triton.compile(str(temp_file))
126+
127+
rs = RandomState(17)
128+
x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
129+
z_shape = (M, 1)
130+
z = np.zeros(z_shape).astype(dtype_str)
131+
132+
x_tri = torch.tensor(x, device=device)
133+
z_tri = torch.tensor(z, device=device)
134+
135+
kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri)
136+
z_ref = numpy_op(x, axis=1, keepdims=True)
137+
138+
if dtype_str == 'float16':
139+
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
140+
else:
141+
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
142+
143+
llir = kernel.asm['llir']
144+
assert re.search(r'call .* asm', llir), 'no inline visa in llir' # inline visa is used

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
5656
triton::ReduceOp op, unsigned numLaneToReduce,
5757
unsigned interleave) const override;
5858

59+
bool warpBatchReduce(RewriterBase &rewriter, Location loc,
60+
std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
61+
triton::ReduceOp op, unsigned numLaneToReduce,
62+
unsigned interleave) const override {
63+
return false;
64+
};
65+
5966
std::string getMulhiFuncName(Type resultElementTy) const override;
6067

6168
void printf(RewriterBase &rewriter, Value formatStrStart,

third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,23 @@ struct XeInstrExecution {
315315
bool onlyAttachMLIRArgs{};
316316
};
317317

318+
enum XeArch {
319+
Xe = 0,
320+
Xe2 = 1,
321+
Xe3 = 2,
322+
};
323+
324+
struct XeVISAInstr : public XeInstrBase<XeVISAInstr> {
325+
using XeInstrBase<XeVISAInstr>::XeInstrBase;
326+
327+
static std::optional<std::string> getTypeName(Type scalarTy);
328+
static unsigned getGRFSizeInBytes(XeArch arch);
329+
static unsigned getExecMaskLaneNum(XeArch arch);
330+
};
331+
332+
std::string simdReduceAsm(std::string binOp, unsigned warpSize,
333+
unsigned accSize, Type elemTy, XeArch arch);
334+
318335
} // namespace triton
319336
} // namespace mlir
320337

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "PatternTritonGPUOpToLLVM.h"
22
#include "lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h"
33
#include "triton/Dialect/Triton/IR/Dialect.h"
4+
#include <triton/Tools/Sys/GetEnv.hpp>
45

56
using namespace mlir;
67
using namespace mlir::triton;
@@ -176,6 +177,18 @@ struct ReduceOpConversion
176177
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
177178
unsigned threadOffsetOnReductionAxis =
178179
helper.getThreadOffsetOnReductionAxis();
180+
181+
bool simdReduce =
182+
triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_SIMD_REDUCE");
183+
184+
if (simdReduce) {
185+
auto ret = targetInfo.warpBatchReduce(rewriter, op.getLoc(), accs, op,
186+
sizeIntraWarps,
187+
threadOffsetOnReductionAxis);
188+
if (ret)
189+
return;
190+
}
191+
179192
for (auto it : accs) {
180193
const SmallVector<unsigned> &key = it.first;
181194
SmallVector<Value> &acc = accs[key];

0 commit comments

Comments
 (0)