Skip to content

Commit 1856d9e

Browse files
committed
[Codegen] Add asin domain check and align NaN‐handling in pooling with CUDA semantics
- Update `tir.asin` to return a quiet NaN when the input is outside [-1, 1]. - Update LLVM codegen for `max`/`min` (used in pooling) to align CPU behavior with CUDA when handling NaN values. - Update the regex in the aarch64 codegen test to also match the NaN-suppressing `fminnm`/`fmaxnm` instructions in addition to `fmin`/`fmax`.
1 parent 3db71bb commit 1856d9e

File tree

5 files changed

+82
-3
lines changed

5 files changed

+82
-3
lines changed

src/target/llvm/codegen_llvm.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,12 +1623,24 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
16231623
// Lower a Min node to LLVM IR.
//
// Float operands that are at least 32 bits wide are emitted as the
// llvm.minnum intrinsic; every other operand type is emitted as a
// compare-and-select. The dtype of operand `a` decides which path is taken.
llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
  llvm::Value* lhs = MakeValue(op->a);
  llvm::Value* rhs = MakeValue(op->b);
  const auto dtype = op->a.dtype();

  const bool use_minnum = dtype.is_float() && dtype.bits() >= 32;
  if (!use_minnum) {
    // Integers, and floats narrower than 32 bits, take the generic
    // "a < b ? a : b" path.
    // NOTE(review): NaN behaviour on this path depends on CreateLT — confirm
    // whether narrow floats should also use minnum.
    llvm::Value* lhs_is_less = CreateLT(dtype, lhs, rhs);
    return builder_->CreateSelect(lhs_is_less, lhs, rhs);
  }

  // IEEE-754 minNum: when one input is NaN the numeric input is kept.
  return builder_->CreateBinaryIntrinsic(llvm::Intrinsic::minnum, lhs, rhs);
}
16281634

16291635
// Lower a Max node to LLVM IR.
//
// Float operands that are at least 32 bits wide are emitted as the
// llvm.maxnum intrinsic; every other operand type is emitted as a
// compare-and-select. The dtype of operand `a` decides which path is taken.
llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
  llvm::Value* lhs = MakeValue(op->a);
  llvm::Value* rhs = MakeValue(op->b);
  const auto dtype = op->a.dtype();

  const bool use_maxnum = dtype.is_float() && dtype.bits() >= 32;
  if (!use_maxnum) {
    // Integers, and floats narrower than 32 bits, take the generic
    // "a > b ? a : b" path.
    // NOTE(review): NaN behaviour on this path depends on CreateGT — confirm
    // whether narrow floats should also use maxnum.
    llvm::Value* lhs_is_greater = CreateGT(dtype, lhs, rhs);
    return builder_->CreateSelect(lhs_is_greater, lhs, rhs);
  }

  // IEEE-754 maxNum: when one input is NaN the numeric input is kept.
  return builder_->CreateBinaryIntrinsic(llvm::Intrinsic::maxnum, lhs, rhs);
}
16341646

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include <tvm/tir/op.h>
3131
#include <tvm/tir/op_attr_types.h>
3232

33+
#include <limits>
34+
3335
#include "../intrin_rule.h"
3436

3537
namespace tvm {
@@ -175,7 +177,15 @@ TVM_REGISTER_OP("tir.asin")
175177
PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112);
176178
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
177179
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
178-
return term1 + term3 + term5 + term7 + term9 + term11;
180+
PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11;
181+
/* --- domain limit check --- */
182+
PrimExpr lower = make_const(x.dtype(), -1.0);
183+
PrimExpr upper = make_const(x.dtype(), 1.0);
184+
PrimExpr out_range = tir::Or(x<lower, x> upper);
185+
// Use a quiet NaN constant
186+
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());
187+
// select: if out of [-1,1] → NaN, else → series
188+
return tir::Select(out_range, nan_const, series);
179189
});
180190

181191
TVM_REGISTER_OP("tir.acos")

tests/python/codegen/test_target_codegen_aarch64.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def check_correct_assembly(type):
184184
)
185185
select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly)
186186
max = re.findall(
187-
r"max\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
187+
r"f?max(?:nm)?\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
188188
)
189189

190190
assert len(loads) > 1
@@ -220,7 +220,7 @@ def check_correct_assembly(type):
220220
)
221221
select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly)
222222
min = re.findall(
223-
r"min\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
223+
r"f?min(?:nm)?\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
224224
)
225225

226226
assert len(loads) > 1

tests/python/tir-base/test_tir_intrin.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,23 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):
100100
func(a, b)
101101
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol)
102102

103+
# Out‐of‐bounds test for asin/acos
104+
name = tvm_intrin.__name__
105+
if name in ("asin", "acos"):
106+
# generate some values outside [-1, 1]
107+
n = 8
108+
out_np = np.concatenate(
109+
[
110+
np.random.uniform(1.1, 2.0, size=n // 2),
111+
np.random.uniform(-2.0, -1.1, size=n // 2),
112+
]
113+
).astype(A.dtype)
114+
a2 = tvm.nd.array(out_np, dev)
115+
b2 = tvm.nd.array(np.empty_like(out_np), dev)
116+
func(a2, b2)
117+
# all outputs should be NaN
118+
assert np.all(np.isnan(b2.numpy()))
119+
103120
for func in test_funcs:
104121
atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5
105122
run_test(*func, atol, rtol)

tests/python/tir-transform/test_tir_transform_lower_intrin.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,46 @@ def test_lower_floormod():
117117
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
118118

119119

120+
# Max / Min NaN-handling (IEEE-754 maxNum / minNum)
121+
@tvm.testing.requires_llvm
def test_lower_maxmin_nan():
    """Check float max/min against IEEE-754 maxNum/minNum NaN semantics.

    For each float dtype under test, ``tvm.te.max(x, y)`` and
    ``tvm.te.min(x, y)`` are passed through ``lower_intrin`` and the result is
    handed to ``check_value`` together with a Python reference in which a
    single NaN operand yields the other operand.
    """

    def get_fp_data():
        # (x, y) pairs, in order:
        #   (a, b) with a < b, (a, NaN), (a, b) with a > b, (NaN, b), (NaN, NaN)
        # The (NaN, b) pair is required so that a NaN in the *first* operand
        # alone is exercised, not only a NaN in the second operand.
        x_vals = [-3.0, 0.0, 7.5, np.nan, np.nan]
        y_vals = [2.0, np.nan, -8.0, 4.0, np.nan]
        return list(zip(x_vals, y_vals))

    def ref_max(a, b):
        # IEEE-754 maxNum semantics: a lone NaN operand is ignored
        if np.isnan(a):
            return b
        if np.isnan(b):
            return a
        return max(a, b)

    def ref_min(a, b):
        # IEEE-754 minNum semantics: a lone NaN operand is ignored
        if np.isnan(a):
            return b
        if np.isnan(b):
            return a
        return min(a, b)

    data = get_fp_data()
    for dtype in ["float32", "float64"]:
        x = te.var("x", dtype=dtype)
        y = te.var("y", dtype=dtype)

        res_max = lower_intrin([x, y], tvm.te.max(x, y))
        check_value(res_max, x, y, data, ref_max)

        res_min = lower_intrin([x, y], tvm.te.min(x, y))
        check_value(res_min, x, y, data, ref_min)
157+
158+
120159
if __name__ == "__main__":
121160
test_lower_floordiv()
122161
test_lower_floormod()
162+
test_lower_maxmin_nan()

0 commit comments

Comments
 (0)