diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index d1703e82a3..09299dc3f5 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -1063,14 +1063,11 @@ def vectorize_over_posterior( for rv in general_toposort( # type: ignore[call-overload] all_rvs, lambda x: x.owner.inputs if x.owner is not None else None ) - if rv in all_rvs + if rv in all_rvs and rv not in needed_rvs ]: - rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs]) - if ( - rv not in needed_rvs - and not ({*outputs, *independent_rvs} & set(rv_ancestors)) - and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs} - ): + blockers = [*needed_rvs, *independent_rvs, *outputs] + rv_ancestors = ancestors([rv], blockers=blockers) + if not (set(blockers) & set(rv_ancestors)): independent_rvs.append(rv) for rv in independent_rvs: replace_dict[rv] = change_dist_size(rv, new_size=batch_shape, expand=True) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 3dd30e14f7..beb93b97ca 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample(): atol=0.6 / np.sqrt(10000), ) assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1) + + +def test_vectorize_over_posterior_with_intermediate_rvs(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Normal.dist(a) + c = b + 1 + d = pm.Normal.dist(c) + idata = pm.sample_prior_predictive(100, var_names=["a"]) + idata.add_groups({"posterior": idata.prior}) + _, _, vectorized_no_intermediate = vectorize_over_posterior( + outputs=[b, c, d], + posterior=idata.posterior, + input_rvs=[a], + allow_rvs_in_graph=True, + ) + [vectorized_intermediate_rvs] = vectorize_over_posterior( + outputs=[d], + posterior=idata.posterior, + input_rvs=[a], + allow_rvs_in_graph=True, + ) + assert vectorized_no_intermediate.type.shape == (1, 100) + assert vectorized_no_intermediate.type.shape == vectorized_intermediate_rvs.type.shape + [a_ancestor1] = get_var_by_name([vectorized_no_intermediate], "a") + [a_ancestor2] = get_var_by_name([vectorized_intermediate_rvs], "a") + assert isinstance(a_ancestor1, TensorConstant) + assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data) + assert isinstance(a_ancestor2, TensorConstant) + assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)