-
Notifications
You must be signed in to change notification settings - Fork 36
3D enabled implementation of ep.pp.filter_observations, ep.pp.filter_features #953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
52c7f61
1fd3be2
45897f9
6fa5d92
8b2cc9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -8,15 +8,17 @@ | |||||
|
||||||
import numpy as np | ||||||
import scanpy as sc | ||||||
import scanpy.logging as logger | ||||||
import scipy.sparse as sp | ||||||
from anndata import AnnData | ||||||
from ehrdata import EHRData | ||||||
from ehrdata.core.constants import MISSING_VALUES | ||||||
|
||||||
from ehrapy._compat import function_2D_only, use_ehrdata | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from collections.abc import Collection, Mapping | ||||||
|
||||||
from ehrdata import EHRData | ||||||
from numpy.typing import NDArray | ||||||
from scanpy.neighbors import KnnTransformerLike | ||||||
from scipy.sparse import spmatrix | ||||||
|
@@ -401,6 +403,225 @@ def neighbors( | |||||
) | ||||||
|
||||||
|
||||||
def filter_features( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your nice new functions are not yet visible in the docs; See https://github.com/theislab/ehrapy/blob/main/docs/api/preprocessing_index.md how to add them, then you can check in the readthedocs build if everything looks as it should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While this is a new function, I'd be still in favor to add the anndata deprecation warning decorator to allow this to handle AnnData for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we try to use single-dispatch more consistently across our functions, see the normalization methods. This is helpful to enforce that our functions work regardless of whether a numpy array, a sparse array, or a dask array is passed. In this case, its enough to make it work for the numpy arrays as you do already - but adding a single-dispatching that raises NotImplementedErrors, should the passed array in |
||||||
edata: EHRData, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*, | ||||||
min_obs: int | None = None, | ||||||
max_obs: int | None = None, | ||||||
time_mode: Literal["all", "any", "proportion"] = "all", | ||||||
prop: float | None = None, | ||||||
copy: bool = False, | ||||||
) -> EHRData | None: # pragma: no cover | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
"""Filter features based on number of observations. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Filter features based on missing data thresholds Might be more descriptive. But you can object here |
||||||
|
||||||
Keep only features which have at least `min_obs` observations | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep only features which have at least Specify once what a non-missing value is, following the comment above |
||||||
or/and have at most `max_obs` observations. | ||||||
When a longitudinal `EHRData` is passed, filtering can be done across time points. | ||||||
|
||||||
Only provide `min_obs` and/or `max_obs` per call. | ||||||
|
||||||
Args: | ||||||
edata: Central data object. | ||||||
min_obs: Minimum number of observations required for a feature to pass filtering. | ||||||
max_obs: Maximum number of observations allowed for a feature to pass filtering. | ||||||
time_mode: How to combine filtering criteria across the time axis. Options are: | ||||||
* `'all'` (default): The feature must pass the filtering criteria in all time points. | ||||||
* `'any'`: The feature must pass the filtering criteria in at least one time point. | ||||||
* `'proportion'`: The feature must pass the filtering criteria in at least a proportion `prop` of time points. For example, with `prop=0.3`, | ||||||
the feature must pass the filtering criteria in at least 30% of the time points. | ||||||
prop: Proportion of time points in which the feature must pass the filtering criteria. Only relevant if `time_mode='proportion'`. | ||||||
copy: Determines whether a copy is returned. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||
|
||||||
Returns: | ||||||
Depending on `copy`, subsets and annotates the passed data object and returns `None` | ||||||
|
||||||
Examples: | ||||||
>>> import ehrapy as ep | ||||||
>>> edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6) | ||||||
>>> edata.R.shape | ||||||
(500, 45, 15) | ||||||
>>> ep.pp.filter_features(edata, min_obs=185, time_mode="all") | ||||||
>>> edata.R.shape | ||||||
(500, 18, 15) | ||||||
|
||||||
sueoglu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
if not isinstance(edata, EHRData): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or AnnData |
||||||
raise TypeError("Data object must be an EHRData object") | ||||||
|
||||||
data = edata.copy() if copy else edata | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we also allow for AnnData, could you add one additional check that the passed layer is 2 dimensional? |
||||||
|
||||||
lower_set = min_obs is not None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
upper_set = max_obs is not None | ||||||
|
||||||
if not (lower_set or upper_set): | ||||||
raise ValueError("You must provide at least one of 'min_obs' and 'max_obs'") | ||||||
|
||||||
if time_mode not in {"all", "any", "proportion"}: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm always a fan to test these exceptions as well, it can be a 2-line test function like we have it for other functions where we test invalid arguments |
||||||
raise ValueError(f"time_mode must be one of 'all', 'any', 'proportion', got {time_mode}") | ||||||
|
||||||
if time_mode == "proportion" and (prop is None or not (0 < prop <= 1)): | ||||||
raise ValueError("prop must be set to a value between 0 and 1 when time_mode is 'proportion'") | ||||||
|
||||||
obs_ax, _var_ax, _time_ax = 0, 1, 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you use |
||||||
|
||||||
threshold_min = min_obs | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, variables |
||||||
threshold_max = max_obs | ||||||
|
||||||
missing_mask = np.isin(data.R, MISSING_VALUES) | np.isnan(data.R) | ||||||
|
||||||
|
||||||
present = ~missing_mask | ||||||
counts = present.sum(axis=obs_ax) | ||||||
|
||||||
if threshold_max is not None and threshold_min is not None: | ||||||
pass_threshold = (threshold_min <= counts) & (counts <= threshold_max) | ||||||
elif threshold_min is not None: | ||||||
pass_threshold = counts >= threshold_min | ||||||
else: | ||||||
pass_threshold = counts <= threshold_max | ||||||
|
||||||
if time_mode == "all": | ||||||
feature_mask = pass_threshold.all(axis=1) | ||||||
elif time_mode == "any": | ||||||
feature_mask = pass_threshold.any(axis=1) | ||||||
elif time_mode == "proportion": | ||||||
if prop is None: | ||||||
raise ValueError("prop must be set when time_mode is 'proportion'") | ||||||
feature_mask = (pass_threshold.sum(axis=1) / pass_threshold.shape[0]) >= prop | ||||||
else: | ||||||
raise ValueError(f"Unknown time_mode: {time_mode}") | ||||||
|
||||||
number_per_feature = counts.sum(axis=1).astype(np.float64) | ||||||
|
||||||
n_filtered = int((~feature_mask).sum()) | ||||||
|
||||||
if n_filtered > 0: | ||||||
msg = f"filtered out {n_filtered} features that are measured " | ||||||
if threshold_min is not None: | ||||||
msg += f"less than {threshold_min} counts" | ||||||
else: | ||||||
msg += f"more than {threshold_max} counts" | ||||||
|
||||||
if time_mode == "proportion": | ||||||
msg += f" in less than {prop * 100:.1f}% of time points" | ||||||
else: | ||||||
msg += f" in {time_mode} time points" | ||||||
logger.info(msg) | ||||||
|
||||||
label = "n_obs_over_time" | ||||||
data.var[label] = number_per_feature | ||||||
data._inplace_subset_var(feature_mask) | ||||||
return data if copy else None | ||||||
|
||||||
|
||||||
def filter_observations( | ||||||
sueoglu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
edata: EHRData, | ||||||
*, | ||||||
min_vars: int | None = None, | ||||||
max_vars: int | None = None, | ||||||
time_mode: Literal["all", "any", "proportion"] = "all", | ||||||
prop: float | None = None, | ||||||
copy: bool = False, | ||||||
) -> EHRData | None: | ||||||
"""Filter observations based on numbers of variables (features/measurements). | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Many comments from above also apply to this function :) |
||||||
|
||||||
Keep only observations which have at least `min_vars` variables and/or at most `max_vars` variables. | ||||||
When a longitudinal `EHRData` is passed, filtering can be done across time points. | ||||||
|
||||||
Only provide `min_vars` and/or `max_vars` per call. | ||||||
|
||||||
Args: | ||||||
edata: Central data object. | ||||||
min_vars: Minimum number of variables required for an observation to pass filtering. | ||||||
max_vars: Maximum number of variables allowed for an observation to pass filtering. | ||||||
time_mode: How to combine filtering criteria across the time axis. Only relevant if an `EHRData` is passed. Options are: | ||||||
* `'all'` (default): The observation must pass the filtering criteria in all time points. | ||||||
* `'any'`: The observation must pass the filtering criteria in at least one time point. | ||||||
* `'proportion'`: The observation must pass the filtering criteria in at least a proportion `prop` of time points. For example, with `prop=0.3`, | ||||||
the observation must pass the filtering criteria in at least 30% of the time points. | ||||||
prop: Proportion of time points in which the observation must pass the filtering criteria. Only relevant if `time_mode='proportion'`. | ||||||
copy: Determines whether a copy is returned. | ||||||
|
||||||
Returns: | ||||||
Depending on `copy`, subsets and annotates the passed data object and returns `None` | ||||||
|
||||||
Examples: | ||||||
>>> import ehrapy as ep | ||||||
>>> edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6) | ||||||
>>> edata.R.shape | ||||||
(500, 45, 15) | ||||||
>>> ep.pp.filter_observations(edata, min_vars=10, time_mode="all") | ||||||
>>> edata.R.shape | ||||||
(477, 45, 15) | ||||||
|
||||||
sueoglu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
if not isinstance(edata, EHRData): | ||||||
raise TypeError("Data object must be an EHRData object") | ||||||
|
||||||
data = edata.copy() if copy else edata | ||||||
|
||||||
lower_set = min_vars is not None | ||||||
upper_set = max_vars is not None | ||||||
|
||||||
if not (lower_set or upper_set): | ||||||
raise ValueError("You must provide at least one of 'min_vars' and 'max_vars'") | ||||||
if time_mode not in {"all", "any", "proportion"}: | ||||||
raise ValueError(f"time_mode must be one of 'all', 'any', 'proportion', got {time_mode}") | ||||||
if time_mode == "proportion" and (prop is None or not (0 < prop <= 1)): | ||||||
raise ValueError("prop must be set to a value between 0 and 1 when time_mode is 'proportion'") | ||||||
|
||||||
threshold_min = min_vars | ||||||
threshold_max = max_vars | ||||||
|
||||||
_obs_ax, var_ax, _time_ax = 0, 1, 2 | ||||||
n_obs, n_vars, n_time = edata.R.shape | ||||||
per_time_vals = np.empty((n_obs, n_time), dtype=float) | ||||||
|
||||||
for t in range(n_time): | ||||||
|
||||||
sliced = data.R[:, :, t] | ||||||
|
||||||
missing_mask = np.isin(sliced, MISSING_VALUES) | np.isnan(sliced) | ||||||
|
||||||
present = ~missing_mask | ||||||
vals = present.sum(axis=var_ax) | ||||||
|
||||||
per_time_vals[:, t] = vals | ||||||
|
||||||
if threshold_min is not None and threshold_max is not None: | ||||||
masks_t = (per_time_vals >= float(threshold_min)) & (per_time_vals <= float(threshold_max)) | ||||||
elif threshold_min is not None: | ||||||
masks_t = per_time_vals >= float(threshold_min) | ||||||
elif threshold_max is not None: | ||||||
masks_t = per_time_vals <= float(threshold_max) | ||||||
|
||||||
if time_mode == "all": | ||||||
obs_mask = masks_t.all(axis=1) | ||||||
elif time_mode == "any": | ||||||
obs_mask = masks_t.any(axis=1) | ||||||
else: | ||||||
obs_mask = masks_t.mean(axis=1) >= float(prop) | ||||||
|
||||||
number_per_obs = per_time_vals.sum(axis=1).astype(np.float64) | ||||||
|
||||||
n_filtered = int((~obs_mask).sum()) | ||||||
if n_filtered > 0: | ||||||
msg = f"filtered out {n_filtered} observations that have" | ||||||
if threshold_min is not None: | ||||||
msg += f"less than {threshold_min} " + "features" | ||||||
else: | ||||||
msg += f"more than {threshold_max} " + "features" | ||||||
if time_mode == "proportion": | ||||||
msg += f" in < {prop * 100:.1f}% of time points" | ||||||
else: | ||||||
msg += f" in {time_mode} time points" | ||||||
|
||||||
logger.info(msg) | ||||||
|
||||||
label = "n_vars_over_time" | ||||||
data.obs[label] = number_per_obs | ||||||
data._inplace_subset_obs(obs_mask) | ||||||
return data if copy else None | ||||||
|
||||||
|
||||||
def _random_resample( | ||||||
label: str | np.ndarray, | ||||||
target: str = "balanced", | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from pathlib import Path | ||
|
||
import ehrdata as ed | ||
import numpy as np | ||
import pytest | ||
|
||
import ehrapy as ep | ||
|
||
CURRENT_DIR = Path(__file__).parent | ||
|
||
|
||
def test_filter_features_invalid_args_min_max_obs(): | ||
edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When layers are allowed, check that this works for layers. You can check in other test cases how this can be done with Also, it make the test more robust if we check that this works in the 2D as well as in the 3D |
||
|
||
# no threshold | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_features(edata) | ||
# invalid time_mode | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_features(edata, min_obs=185, time_mode="invalid_mode", copy=False) | ||
# invalid prop | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_features(edata, min_obs=185, time_mode="proportion", prop=2, copy=False) | ||
|
||
# min_obs filtering | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be split into a separate test - its not doing what the test name says it is doing |
||
n_vars_before = edata.R.shape[1] | ||
ep.pp.filter_features(edata, min_obs=185, time_mode="all", copy=False) | ||
n_vars_after = edata.R.shape[1] | ||
assert n_vars_after < n_vars_before | ||
|
||
# max_obs filtering | ||
n_vars_before = edata.R.shape[1] | ||
ep.pp.filter_features(edata, max_obs=200, time_mode="all", copy=False) | ||
n_vars_after = edata.R.shape[1] | ||
assert n_vars_after < n_vars_before | ||
|
||
|
||
def test_filter_obs_invalid_args_min_max_vars(): | ||
edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6) | ||
|
||
# no threshold | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_observations(edata) | ||
# invalid time_mode | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_observations(edata, min_vars=10, time_mode="invalid_mode", copy=False) | ||
# invalid prop | ||
with pytest.raises(ValueError): | ||
ep.pp.filter_observations(edata, min_vars=10, time_mode="proportion", prop=2, copy=False) | ||
|
||
# min_vars filtering | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
n_obs_before = edata.R.shape[0] | ||
ep.pp.filter_observations(edata, min_vars=10, time_mode="all", copy=False) | ||
n_obs_after = edata.R.shape[0] | ||
assert n_obs_after < n_obs_before | ||
|
||
# max_vars filtering | ||
n_obs_before = edata.R.shape[0] | ||
ep.pp.filter_observations(edata, max_vars=12, time_mode="all", copy=False) | ||
n_obs_after = edata.R.shape[0] | ||
assert n_obs_after < n_obs_before |
Uh oh!
There was an error while loading. Please reload this page.