Skip to content
46 changes: 41 additions & 5 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BetaRV,
WeibullRV,
cauchy,
chisquare,
exponential,
gamma,
gumbel,
Expand Down Expand Up @@ -2548,7 +2549,7 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(Gamma):
class ChiSquared(PositiveContinuous):
r"""
:math:`\chi^2` log-likelihood.

Expand Down Expand Up @@ -2583,13 +2584,48 @@ class ChiSquared(Gamma):

Parameters
----------
nu: int
nu: float
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

def __init__(self, nu, *args, **kwargs):
self.nu = nu = at.as_tensor_variable(floatX(nu))
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)
@classmethod
def dist(cls, nu, *args, **kwargs):
nu = at.as_tensor_variable(floatX(nu))
return super().dist([nu], *args, **kwargs)

def logp(value, nu):
"""
Calculate log-probability of ChiSquared distribution at specified value.

Parameters
----------
value: numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or Aesara tensor

Returns
-------
TensorVariable
"""
return Gamma.logp(value, nu / 2, 2)

def logcdf(value, nu):
"""
Compute the log of the cumulative distribution function for ChiSquared distribution
at the specified value.

Parameters
----------
value: numeric or np.ndarray or `TensorVariable`
Value(s) for which log CDF is calculated. If the log CDF for
multiple values are desired the values must be provided in a numpy
array or `TensorVariable`.
Returns
-------
TensorVariable
"""
return Gamma.logcdf(value, nu / 2, 2)


# TODO: Remove this once logpt for multiplication is working!
Expand Down
17 changes: 14 additions & 3 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,15 +1030,26 @@ def test_half_normal(self):
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_chi_squared(self):
def test_chisquared_logp(self):
self.check_logp(
ChiSquared,
Rplus,
{"nu": Rplusdunif},
{"nu": Rplus},
lambda value, nu: sp.chi2.logpdf(value, df=nu),
)

@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to numerical issues",
)
def test_chisquared_logcdf(self):
self.check_logcdf(
ChiSquared,
Rplus,
{"nu": Rplus},
lambda value, nu: sp.chi2.logcdf(value, df=nu),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_wald_logp(self):
self.check_logp(
Expand Down
19 changes: 13 additions & 6 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,6 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestChiSquared(BaseTestCases.BaseTestCase):
distribution = pm.ChiSquared
params = {"nu": 2.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestExGaussian(BaseTestCases.BaseTestCase):
distribution = pm.ExGaussian
Expand Down Expand Up @@ -753,6 +747,19 @@ class TestInverseGammaMuSigma(BaseTestDistribution):
tests_to_run = ["check_pymc_params_match_rv_op"]


class TestChiSquared(BaseTestDistribution):
pymc_dist = pm.ChiSquared
pymc_dist_params = {"nu": 2.0}
expected_rv_op_params = {"nu": 2.0}
reference_dist_params = {"df": 2.0}
reference_dist = seeded_numpy_distribution_builder("chisquare")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestBinomial(BaseTestDistribution):
pymc_dist = pm.Binomial
pymc_dist_params = {"n": 100, "p": 0.33}
Expand Down