Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include <limits>

#include "../intrin_rule.h"

namespace tvm {
Expand Down Expand Up @@ -175,7 +177,15 @@ TVM_REGISTER_OP("tir.asin")
PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112);
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
return term1 + term3 + term5 + term7 + term9 + term11;
PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11;
/* --- domain limit check --- */
PrimExpr lower = make_const(x.dtype(), -1.0);
PrimExpr upper = make_const(x.dtype(), 1.0);
PrimExpr out_range = tir::Or(x<lower, x> upper);
// Use a quiet NaN constant
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());
// select: if out of [-1,1] → NaN, else → series
return tir::Select(out_range, nan_const, series);
});

TVM_REGISTER_OP("tir.acos")
Expand Down
17 changes: 17 additions & 0 deletions tests/python/tir-base/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,23 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):
func(a, b)
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol)

# Out‐of‐bounds test for asin/acos
name = tvm_intrin.__name__
if name in ("asin", "acos"):
# generate some values outside [-1, 1]
n = 8
out_np = np.concatenate(
[
np.random.uniform(1.1, 2.0, size=n // 2),
np.random.uniform(-2.0, -1.1, size=n // 2),
]
).astype(A.dtype)
a2 = tvm.nd.array(out_np, dev)
b2 = tvm.nd.array(np.empty_like(out_np), dev)
func(a2, b2)
# all outputs should be NaN
assert np.all(np.isnan(b2.numpy()))

for func in test_funcs:
atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5
run_test(*func, atol, rtol)
Expand Down