Skip to content

Commit 0c48286

Browse files
vacu9708 authored and ShiboXing committed
[Codegen] Resolve issue apache#17965 where the same model produces different outputs on the LLVM (CPU) and CUDA (GPU) backends (apache#17985)
[Codegen] Add asin domain check - Update `tir.asin` to return quiet NaN if the input is outside of [-1, 1]
1 parent 0bf74c1 commit 0c48286

File tree

2 files changed: 28 additions and 1 deletion

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,8 @@
3030
#include <tvm/tir/op.h>
3131
#include <tvm/tir/op_attr_types.h>
3232

33+
#include <limits>
34+
3335
#include "../intrin_rule.h"
3436

3537
namespace tvm {
@@ -175,7 +177,15 @@ TVM_REGISTER_OP("tir.asin")
175177
PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112);
176178
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
177179
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
178-
return term1 + term3 + term5 + term7 + term9 + term11;
180+
PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11;
181+
/* --- domain limit check --- */
182+
PrimExpr lower = make_const(x.dtype(), -1.0);
183+
PrimExpr upper = make_const(x.dtype(), 1.0);
184+
PrimExpr out_range = tir::Or(x<lower, x> upper);
185+
// Use a quiet NaN constant
186+
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());
187+
// select: if out of [-1,1] → NaN, else → series
188+
return tir::Select(out_range, nan_const, series);
179189
});
180190

181191
TVM_REGISTER_OP("tir.acos")

tests/python/tir-base/test_tir_intrin.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,23 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):
100100
func(a, b)
101101
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol)
102102

103+
# Out‐of‐bounds test for asin/acos
104+
name = tvm_intrin.__name__
105+
if name in ("asin", "acos"):
106+
# generate some values outside [-1, 1]
107+
n = 8
108+
out_np = np.concatenate(
109+
[
110+
np.random.uniform(1.1, 2.0, size=n // 2),
111+
np.random.uniform(-2.0, -1.1, size=n // 2),
112+
]
113+
).astype(A.dtype)
114+
a2 = tvm.nd.array(out_np, dev)
115+
b2 = tvm.nd.array(np.empty_like(out_np), dev)
116+
func(a2, b2)
117+
# all outputs should be NaN
118+
assert np.all(np.isnan(b2.numpy()))
119+
103120
for func in test_funcs:
104121
atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5
105122
run_test(*func, atol, rtol)

0 commit comments

Comments
 (0)