From 630ea0166492e011d697cf231746c3b1a754cefb Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Sat, 9 Sep 2023 16:22:23 +0100 Subject: [PATCH] fix(nmod): ZeroDivisionError instead of coredump --- src/flint/test/test.py | 13 +++++++--- src/flint/types/nmod.pyx | 56 ++++++++++++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/flint/test/test.py b/src/flint/test/test.py index 67263ca8..9c9dc8fe 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -1279,6 +1279,7 @@ def test_nmod(): assert G(1,2) != G(0,2) assert G(0,2) != G(0,3) assert G(3,5) == G(8,5) + assert G(1,2) != (1,2) assert isinstance(hash(G(3, 5)), int) assert raises(lambda: G([], 3), TypeError) #assert G(3,5) == 8 # do we want this? @@ -1304,14 +1305,20 @@ def test_nmod(): assert G(0,3) / G(1,3) == G(0,3) assert G(3,17) * flint.fmpq(11,5) == G(10,17) assert G(3,17) / flint.fmpq(11,5) == G(6,17) + assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError) + assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError) + assert raises(lambda: G(2,5) / 0, ZeroDivisionError) + assert G(1,6) / G(5,6) == G(5,6) + assert raises(lambda: G(1,6) / G(3,6), ZeroDivisionError) assert G(1,3) ** 2 == G(1,3) assert G(2,3) ** flint.fmpz(2) == G(1,3) + assert ~G(2,7) == G(2,7) ** -1 == G(4,7) + assert raises(lambda: G(3,6) ** -1, ZeroDivisionError) + assert raises(lambda: ~G(3,6), ZeroDivisionError) + assert raises(lambda: pow(G(1,3), 2, 7), TypeError) assert G(flint.fmpq(2, 3), 5) == G(4,5) assert raises(lambda: G(2,5) ** G(2,5), TypeError) assert raises(lambda: flint.fmpz(2) ** G(2,5), TypeError) - assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError) - assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError) - assert raises(lambda: G(2,5) / 0, ZeroDivisionError) assert raises(lambda: G(2,5) + G(2,7), ValueError) assert raises(lambda: G(2,5) - G(2,7), ValueError) assert raises(lambda: G(2,5) * G(2,7), ValueError) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 87abfb5b..c0a01bb8 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -5,12 +5,14 @@ from flint.types.fmpz cimport any_as_fmpz from flint.types.fmpz cimport fmpz from flint.types.fmpq cimport fmpq +from flint.flintlib.flint cimport ulong from flint.flintlib.fmpz cimport fmpz_t from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv from flint.flintlib.nmod_vec cimport * from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui from flint.flintlib.fmpq cimport fmpq_mod_fmpz +from flint.flintlib.ulong_extras cimport n_gcdinv cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: cdef int success @@ -64,9 +66,6 @@ cdef class nmod(flint_scalar): def __int__(self): return int(self.val) - def __long__(self): - return self.val - def modulus(self): return self.mod.n @@ -170,6 +169,8 @@ cdef class nmod(flint_scalar): cdef nmod r cdef mp_limb_t sval, tval, x cdef nmod_t mod + cdef ulong tinvval + if typecheck(s, nmod): mod = (s).mod sval = (s).val @@ -180,17 +181,19 @@ cdef class nmod(flint_scalar): tval = (t).val if not any_as_nmod(&sval, s, mod): return NotImplemented + if tval == 0: raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) if not s: return s - # XXX: check invertibility? - x = nmod_div(sval, tval, mod) - if x == 0: + + g = n_gcdinv(&tinvval, tval, mod.n) + if g != 1: raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + r = nmod.__new__(nmod) r.mod = mod - r.val = x + r.val = nmod_mul(sval, tinvval, mod) return r def __truediv__(s, t): @@ -200,18 +203,43 @@ cdef class nmod(flint_scalar): return nmod._div_(t, s) def __invert__(self): - return (1 / self) # XXX: speed up + cdef nmod r + cdef ulong g, inv, sval + sval = (self).val + g = n_gcdinv(&inv, sval, self.mod.n) + if g != 1: + raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n)) + r = nmod.__new__(nmod) + r.mod = self.mod + r.val = inv + return r - def __pow__(self, exp): + def __pow__(self, exp, modulus=None): cdef nmod r + cdef mp_limb_t rval, mod + cdef ulong g, rinv + + if modulus is not None: + raise TypeError("three-argument pow() not supported by nmod") + e = any_as_fmpz(exp) if e is NotImplemented: return NotImplemented - r = nmod.__new__(nmod) - r.mod = self.mod - r.val = self.val + + rval = (self).val + mod = (self).mod.n + + # XXX: It is not clear that it is necessary to special case negative + # exponents here. The nmod_pow_fmpz function seems to handle this fine + # but the Flint docs say that the exponent must be nonnegative. if e < 0: - r.val = nmod_inv(r.val, self.mod) + g = n_gcdinv(&rinv, rval, mod) + if g != 1: + raise ZeroDivisionError("%s is not invertible mod %s" % (rval, mod)) + rval = rinv e = -e - r.val = nmod_pow_fmpz(r.val, (e).val, self.mod) + + r = nmod.__new__(nmod) + r.mod = self.mod + r.val = nmod_pow_fmpz(rval, (e).val, self.mod) return r