223 changes: 222 additions & 1 deletion ehrapy/preprocessing/_scanpy_pp_api.py
@@ -8,15 +8,17 @@

import numpy as np
import scanpy as sc
import scanpy.logging as logger
import scipy.sparse as sp
from anndata import AnnData
from ehrdata import EHRData
from ehrdata.core.constants import MISSING_VALUES

from ehrapy._compat import function_2D_only, use_ehrdata

if TYPE_CHECKING:
    from collections.abc import Collection, Mapping

    from ehrdata import EHRData
    from numpy.typing import NDArray
    from scanpy.neighbors import KnnTransformerLike
    from scipy.sparse import spmatrix
@@ -401,6 +403,225 @@ def neighbors(
)


def filter_features(
Collaborator

Your nice new functions are not yet visible in the docs. See https://github.com/theislab/ehrapy/blob/main/docs/api/preprocessing_index.md for how to add them; you can then check in the Read the Docs build that everything looks as it should.

Collaborator

Although this is a new function, I'd still be in favor of adding the AnnData deprecation warning decorator so that it can handle AnnData for now.

Collaborator

We try to use single dispatch more consistently across our functions; see the normalization methods. This helps enforce that our functions work regardless of whether a numpy array, a sparse array, or a dask array is passed.

In this case, it's enough to make it work for numpy arrays, as you already do. But adding a single-dispatch function that raises NotImplementedError when the array in edata is a scipy.sparse or dask array will help users notice immediately that they have to convert their data to a numpy array if they want to use this function.

    edata: EHRData,
Collaborator

Suggested change
-    edata: EHRData,
+    edata: EHRData | AnnData,

    *,
    min_obs: int | None = None,
    max_obs: int | None = None,
    time_mode: Literal["all", "any", "proportion"] = "all",
    prop: float | None = None,
    copy: bool = False,
) -> EHRData | None:  # pragma: no cover
Collaborator

Suggested change
-) -> EHRData | None:  # pragma: no cover
+) -> EHRData | AnnData | None:  # pragma: no cover

    """Filter features based on number of observations.
Collaborator

"Filter features based on missing data thresholds" might be more descriptive. But feel free to object here.


    Keep only features which have at least `min_obs` observations
Collaborator

"Keep only features which have at least min_obs observations, meaning non-missing values."

Specify once what a non-missing value is, following the comment above.

    and/or have at most `max_obs` observations.
    When a longitudinal `EHRData` is passed, filtering can be done across time points.

    At least one of `min_obs` and `max_obs` must be provided per call.

    Args:
        edata: Central data object.
        min_obs: Minimum number of observations required for a feature to pass filtering.
        max_obs: Maximum number of observations allowed for a feature to pass filtering.
        time_mode: How to combine filtering criteria across the time axis. Options are:
            * `'all'` (default): The feature must pass the filtering criteria in all time points.
            * `'any'`: The feature must pass the filtering criteria in at least one time point.
            * `'proportion'`: The feature must pass the filtering criteria in at least a proportion `prop` of time points. For example, with `prop=0.3`,
              the feature must pass the filtering criteria in at least 30% of the time points.
        prop: Proportion of time points in which the feature must pass the filtering criteria. Only relevant if `time_mode='proportion'`.
        copy: Determines whether a copy is returned.
Collaborator

The layer argument is missing. I think it's suitable to have, right? I might be overlooking something though.


    Returns:
        If `copy=True`, a filtered and annotated copy of the data object. Otherwise, the passed data object is subset and annotated in place and `None` is returned.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6)
        >>> edata.R.shape
        (500, 45, 15)
        >>> ep.pp.filter_features(edata, min_obs=185, time_mode="all")
        >>> edata.R.shape
        (500, 18, 15)

    """
    if not isinstance(edata, EHRData):
Collaborator

or AnnData

        raise TypeError("Data object must be an EHRData object")

    data = edata.copy() if copy else edata
Collaborator

If we also allow for AnnData, could you add one additional check that the passed layer is 2-dimensional?


    lower_set = min_obs is not None
Collaborator

lower_set and upper_set are only used once; in that case you can omit these variables and put min_obs is not None and max_obs is not None directly into the if statement.

    upper_set = max_obs is not None

    if not (lower_set or upper_set):
        raise ValueError("You must provide at least one of 'min_obs' and 'max_obs'")

    if time_mode not in {"all", "any", "proportion"}:
Collaborator

I'm always a fan of testing these exceptions as well; it can be a 2-line test function like the ones we have for other functions where we test invalid arguments.

        raise ValueError(f"time_mode must be one of 'all', 'any', 'proportion', got {time_mode}")

    if time_mode == "proportion" and (prop is None or not (0 < prop <= 1)):
        raise ValueError("prop must be set to a value between 0 and 1 when time_mode is 'proportion'")

    obs_ax, _var_ax, _time_ax = 0, 1, 2
Collaborator

You use obs_ax and _var_ax sometimes, but sometimes just pass the numbers. Sticking to the numbers here, without dedicated variables, would be more consistent with how we do it across the package :)


    threshold_min = min_obs
Collaborator

Again, I think the threshold_* variables are only used once.

    threshold_max = max_obs

    missing_mask = np.isin(data.R, MISSING_VALUES) | np.isnan(data.R)
Collaborator

Huh, I'll make a note to myself that np.nan should be added to MISSING_VALUES.
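
For context on why the explicit `np.isnan` term matters: `np.isin` compares by equality, and NaN never compares equal to itself, so a NaN placed in the lookup list is not matched. A small sketch, where `SENTINELS` is a made-up stand-in for `MISSING_VALUES` (whose actual contents are not shown in this diff):

```python
import numpy as np

# Hypothetical sentinel list; the real MISSING_VALUES constant may differ.
SENTINELS = [-999.0]

x = np.array([1.0, np.nan, -999.0, 4.0])

# Adding np.nan to the lookup list has no effect: NaN != NaN under equality.
sentinel_mask = np.isin(x, SENTINELS + [np.nan])

# NaN therefore has to be detected separately and combined with the sentinel mask.
nan_mask = np.isnan(x)
missing_mask = sentinel_mask | nan_mask
```

Here `sentinel_mask` flags only the `-999.0` entry, while `missing_mask` also flags the NaN, so keeping an `np.isnan` term next to `np.isin` stays necessary even if `np.nan` is added to the constant.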


    present = ~missing_mask
    counts = present.sum(axis=obs_ax)

    if threshold_max is not None and threshold_min is not None:
        pass_threshold = (threshold_min <= counts) & (counts <= threshold_max)
    elif threshold_min is not None:
        pass_threshold = counts >= threshold_min
    else:
        pass_threshold = counts <= threshold_max

    if time_mode == "all":
        feature_mask = pass_threshold.all(axis=1)
    elif time_mode == "any":
        feature_mask = pass_threshold.any(axis=1)
    elif time_mode == "proportion":
        if prop is None:
            raise ValueError("prop must be set when time_mode is 'proportion'")
        # pass_threshold has shape (n_vars, n_time); the proportion is taken over time points (axis 1)
        feature_mask = pass_threshold.mean(axis=1) >= prop
    else:
        raise ValueError(f"Unknown time_mode: {time_mode}")

    number_per_feature = counts.sum(axis=1).astype(np.float64)

    n_filtered = int((~feature_mask).sum())

    if n_filtered > 0:
        msg = f"filtered out {n_filtered} features that are measured "
        if threshold_min is not None:
            msg += f"less than {threshold_min} counts"
        else:
            msg += f"more than {threshold_max} counts"

        if time_mode == "proportion":
            msg += f" in less than {prop * 100:.1f}% of time points"
        else:
            msg += f" in {time_mode} time points"
        logger.info(msg)

    label = "n_obs_over_time"
    data.var[label] = number_per_feature
    data._inplace_subset_var(feature_mask)
    return data if copy else None
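
As a rough illustration of how the three `time_mode` options reduce a per-feature, per-time-point boolean array, consider the toy sketch below. The array values and the `prop` of 0.3 are made up for the example; the point is that the proportion is computed over the time axis (axis 1), i.e. divided by the number of time points rather than the number of features.

```python
import numpy as np

# Rows are features, columns are time points (shape: n_vars x n_time).
pass_threshold = np.array(
    [
        [True, True, True, True],      # passes at every time point
        [True, False, False, False],   # passes at 1 of 4 time points
        [False, False, False, False],  # never passes
    ]
)

all_mask = pass_threshold.all(axis=1)           # time_mode="all"
any_mask = pass_threshold.any(axis=1)           # time_mode="any"
prop_mask = pass_threshold.mean(axis=1) >= 0.3  # time_mode="proportion", prop=0.3
```

The second feature passes in 25% of the time points, so it is kept under `"any"` but dropped under `"all"` and under `"proportion"` with `prop=0.3`.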


def filter_observations(
    edata: EHRData,
    *,
    min_vars: int | None = None,
    max_vars: int | None = None,
    time_mode: Literal["all", "any", "proportion"] = "all",
    prop: float | None = None,
    copy: bool = False,
) -> EHRData | None:
    """Filter observations based on numbers of variables (features/measurements).
Collaborator

Many comments from above also apply to this function :)


    Keep only observations which have at least `min_vars` variables and/or at most `max_vars` variables.
    When a longitudinal `EHRData` is passed, filtering can be done across time points.

    At least one of `min_vars` and `max_vars` must be provided per call.

    Args:
        edata: Central data object.
        min_vars: Minimum number of variables required for an observation to pass filtering.
        max_vars: Maximum number of variables allowed for an observation to pass filtering.
        time_mode: How to combine filtering criteria across the time axis. Only relevant if an `EHRData` is passed. Options are:
            * `'all'` (default): The observation must pass the filtering criteria in all time points.
            * `'any'`: The observation must pass the filtering criteria in at least one time point.
            * `'proportion'`: The observation must pass the filtering criteria in at least a proportion `prop` of time points. For example, with `prop=0.3`,
              the observation must pass the filtering criteria in at least 30% of the time points.
        prop: Proportion of time points in which the observation must pass the filtering criteria. Only relevant if `time_mode='proportion'`.
        copy: Determines whether a copy is returned.

    Returns:
        If `copy=True`, a filtered and annotated copy of the data object. Otherwise, the passed data object is subset and annotated in place and `None` is returned.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6)
        >>> edata.R.shape
        (500, 45, 15)
        >>> ep.pp.filter_observations(edata, min_vars=10, time_mode="all")
        >>> edata.R.shape
        (477, 45, 15)

    """
    if not isinstance(edata, EHRData):
        raise TypeError("Data object must be an EHRData object")

    data = edata.copy() if copy else edata

    lower_set = min_vars is not None
    upper_set = max_vars is not None

    if not (lower_set or upper_set):
        raise ValueError("You must provide at least one of 'min_vars' and 'max_vars'")
    if time_mode not in {"all", "any", "proportion"}:
        raise ValueError(f"time_mode must be one of 'all', 'any', 'proportion', got {time_mode}")
    if time_mode == "proportion" and (prop is None or not (0 < prop <= 1)):
        raise ValueError("prop must be set to a value between 0 and 1 when time_mode is 'proportion'")

    threshold_min = min_vars
    threshold_max = max_vars

    _obs_ax, var_ax, _time_ax = 0, 1, 2
    n_obs, n_vars, n_time = edata.R.shape
    per_time_vals = np.empty((n_obs, n_time), dtype=float)

    for t in range(n_time):
Collaborator

The for loop can be vectorized. I hope the suggestion below is correct; please double-check :D

missing_mask = np.isin(data.R, MISSING_VALUES) | np.isnan(data.R)
per_time_vals = (~missing_mask).sum(axis=1)
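
One way to double-check the suggestion is to compare both variants on random data. The sketch below does this on a small synthetic array, treating only NaN as missing for simplicity (the `MISSING_VALUES` sentinels are left out, which is an assumption of the example); the names `looped` and `vectorized` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic (n_obs, n_vars, n_time) array with roughly 40% of the values missing.
R = rng.normal(size=(5, 4, 3))
R[rng.random(R.shape) < 0.4] = np.nan
n_obs, n_vars, n_time = R.shape

# Loop variant: count present values per observation, one time slice at a time.
looped = np.empty((n_obs, n_time), dtype=float)
for t in range(n_time):
    looped[:, t] = (~np.isnan(R[:, :, t])).sum(axis=1)

# Vectorized variant: summing over the variable axis keeps (n_obs, n_time).
vectorized = (~np.isnan(R)).sum(axis=1)

assert vectorized.shape == (n_obs, n_time)
assert np.array_equal(looped, vectorized)
```

The only difference is the dtype: the vectorized result is an integer array, whereas the loop fills a float array, which does not affect the threshold comparisons that follow.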

        sliced = data.R[:, :, t]

        missing_mask = np.isin(sliced, MISSING_VALUES) | np.isnan(sliced)

        present = ~missing_mask
        vals = present.sum(axis=var_ax)

        per_time_vals[:, t] = vals

    if threshold_min is not None and threshold_max is not None:
        masks_t = (per_time_vals >= float(threshold_min)) & (per_time_vals <= float(threshold_max))
    elif threshold_min is not None:
        masks_t = per_time_vals >= float(threshold_min)
    elif threshold_max is not None:
        masks_t = per_time_vals <= float(threshold_max)

    if time_mode == "all":
        obs_mask = masks_t.all(axis=1)
    elif time_mode == "any":
        obs_mask = masks_t.any(axis=1)
    else:
        obs_mask = masks_t.mean(axis=1) >= float(prop)

    number_per_obs = per_time_vals.sum(axis=1).astype(np.float64)

    n_filtered = int((~obs_mask).sum())
    if n_filtered > 0:
        # trailing space so the message does not read "havefewer"/"havemore"
        msg = f"filtered out {n_filtered} observations that have "
        if threshold_min is not None:
            msg += f"fewer than {threshold_min} features"
        else:
            msg += f"more than {threshold_max} features"
        if time_mode == "proportion":
            msg += f" in < {prop * 100:.1f}% of time points"
        else:
            msg += f" in {time_mode} time points"

        logger.info(msg)

    label = "n_vars_over_time"
    data.obs[label] = number_per_obs
    data._inplace_subset_obs(obs_mask)
    return data if copy else None


def _random_resample(
label: str | np.ndarray,
target: str = "balanced",
61 changes: 61 additions & 0 deletions tests/preprocessing/test_filter_features.py
@@ -0,0 +1,61 @@
from pathlib import Path

import ehrdata as ed
import numpy as np
import pytest

import ehrapy as ep

CURRENT_DIR = Path(__file__).parent


def test_filter_features_invalid_args_min_max_obs():
    edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6)
Collaborator

When layers are allowed, check that this works for layers. You can see in other test cases how this can be done with pytest.mark.parametrize.

Also, it makes the test more robust if we check that this works in 2D as well as in 3D.


    # no threshold
    with pytest.raises(ValueError):
        ep.pp.filter_features(edata)
    # invalid time_mode
    with pytest.raises(ValueError):
        ep.pp.filter_features(edata, min_obs=185, time_mode="invalid_mode", copy=False)
    # invalid prop
    with pytest.raises(ValueError):
        ep.pp.filter_features(edata, min_obs=185, time_mode="proportion", prop=2, copy=False)

    # min_obs filtering
Collaborator

This can be split into a separate test; it's not doing what the test name says it is doing.

    n_vars_before = edata.R.shape[1]
    ep.pp.filter_features(edata, min_obs=185, time_mode="all", copy=False)
    n_vars_after = edata.R.shape[1]
    assert n_vars_after < n_vars_before

    # max_obs filtering
    n_vars_before = edata.R.shape[1]
    ep.pp.filter_features(edata, max_obs=200, time_mode="all", copy=False)
    n_vars_after = edata.R.shape[1]
    assert n_vars_after < n_vars_before


def test_filter_obs_invalid_args_min_max_vars():
    edata = ed.dt.ehrdata_blobs(n_variables=45, n_observations=500, base_timepoints=15, missing_values=0.6)

    # no threshold
    with pytest.raises(ValueError):
        ep.pp.filter_observations(edata)
    # invalid time_mode
    with pytest.raises(ValueError):
        ep.pp.filter_observations(edata, min_vars=10, time_mode="invalid_mode", copy=False)
    # invalid prop
    with pytest.raises(ValueError):
        ep.pp.filter_observations(edata, min_vars=10, time_mode="proportion", prop=2, copy=False)

    # min_vars filtering
Collaborator

Same here.

    n_obs_before = edata.R.shape[0]
    ep.pp.filter_observations(edata, min_vars=10, time_mode="all", copy=False)
    n_obs_after = edata.R.shape[0]
    assert n_obs_after < n_obs_before

    # max_vars filtering
    n_obs_before = edata.R.shape[0]
    ep.pp.filter_observations(edata, max_vars=12, time_mode="all", copy=False)
    n_obs_after = edata.R.shape[0]
    assert n_obs_after < n_obs_before