4 changes: 2 additions & 2 deletions main.nf
@@ -40,7 +40,7 @@ workflow{
        // Integrate annotations back into pose files
        // This branch requires files to be local and already url-ified
        PREPARE_DATA(params.input_batch, params.location, true)
        INTEGRATE_CORNER_ANNOTATIONS(PREPATE_DATA.out.file_processing_channel, params.sleap_file)
        INTEGRATE_CORNER_ANNOTATIONS(PREPARE_DATA.out.file_processing_channel, params.sleap_file)
        ADD_DUMMY_VIDEO(INTEGRATE_CORNER_ANNOTATIONS.out, params.clip_duration)
        paired_video_and_pose = ADD_DUMMY_VIDEO.out[0]

@@ -50,7 +50,7 @@ workflow{
    if (params.workflow == "single-mouse-v6-features"){
        PREPARE_DATA(params.input_batch, params.location, false)
        // Generate features from pose_v6 files
        ADD_DUMMY_VIDEO(PREPARE_DATA.out.out_file, params.clip_duration)
        ADD_DUMMY_VIDEO(PREPARE_DATA.out.file_processing_channel, params.clip_duration)
        paired_video_and_pose = ADD_DUMMY_VIDEO.out[0]
        SINGLE_MOUSE_V6_FEATURES(paired_video_and_pose)
    }
12 changes: 6 additions & 6 deletions nextflow/configs/profiles/sumner2.config
@@ -192,16 +192,16 @@ process {
     * Runtime options
     */
    withLabel: "tracking" {
        container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/deployment-runtime_2025-08-26.sif"
        container = "/projects/kumar-lab/meta/images/mouse-tracking-runtime/runtime/v0.1.2/latest.sif"
    }
    withLabel: "jabs_classify" {
        container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/JABS-GUI_2025-02-12_v0.18.1.sif"
        container = "/projects/kumar-lab/meta/images/JABS-behavior-classifier/headless/v0.36.1/latest.sif"
    }
    withLabel: "jabs_postprocess" {
        container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/JABS-Postprocessing-2025-03-27_864d687.sif"
        container = "/projects/kumar-lab/meta/images/JABS-postprocess/jabs-postprocess/v0.3.1/latest.sif"
    }
    withLabel: "jabs_table_convert" {
        container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/support-r-code_2025-02-11.sif"
        container = "/projects/kumar-lab/meta/images/mouse-tracking-runtime/RBase/v0.1.2/latest.sif"
    }
    withLabel: "gait" {
        container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/gait-pipeline-2025-03-27.sif"
@@ -210,10 +210,10 @@ process {
container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/vfi-2025-03-27.sif"
}
withLabel: "sleap" {
container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/sleap-1.4.1.sif"
container = "/projects/kumar-lab/meta/images/mouse-tracking-runtime/sleap-1.4.1/v0.1.2/latest.sif"
}
withLabel: "sleap_io" {
container = "/projects/kumar-lab/multimouse-pipeline/nextflow-containers/sleap-io-0.2.0.sif"
container = "/projects/kumar-lab/meta/images/mouse-tracking-runtime/sleap-io-0.2.0/v0.1.2/latest.sif"
}
withLabel: "rclone" {
// executor.queueSize = 1
4 changes: 2 additions & 2 deletions nextflow/workflows/io.nf
@@ -54,9 +54,9 @@ workflow PREPARE_DATA {

    // Files should be appropriately URLified to avoid collisions within the pipeline
    if (skip_urlify) {
        file_processing_channel = file_batch.readLines().flatMap { line -> file(line) }
        file_processing_channel = file_batch.splitText().map { line -> file(line.trim()) }
    } else {
        file_processing_channel = URLIFY_FILE(file_batch.readLines().flatMap(), params.path_depth).file
        file_processing_channel = URLIFY_FILE(file_batch.splitText().map { it.trim() }, params.path_depth).file
    }

    emit:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ gpu = [
"torch==2.6.0",
"torchvision==0.21.0",
"torchaudio==2.6.0",
"nvidia-cusparselt-cu12==0.6.3",
]

# CPU-only convenience for local tests (unchanged idea)
23 changes: 22 additions & 1 deletion src/mouse_tracking/cli/utils.py
@@ -6,13 +6,15 @@
from rich import print

from mouse_tracking import __version__
from mouse_tracking.core.config.pose_utils import PoseUtilsConfig
from mouse_tracking.matching.match_predictions import match_predictions
from mouse_tracking.pose import render
from mouse_tracking.pose.convert import downgrade_pose_file
from mouse_tracking.utils import fecal_boli, static_objects
from mouse_tracking.utils.clip_video import clip_video_auto, clip_video_manual
from mouse_tracking.utils.writers import downgrade_pose_file, filter_large_poses

app = typer.Typer()
CONFIG = PoseUtilsConfig()


def version_callback(value: bool) -> None:
@@ -248,3 +250,22 @@ def stitch_tracklets(
    This command stitches tracklets from the specified source.
    """
    match_predictions(in_pose)


@app.command()
def filter_large_area_pose(
    in_pose: Path = typer.Argument(..., help="Input HDF5 pose file"),
    max_area: int = typer.Option(
        CONFIG.OFA_MAX_EXPECTED_AREA_PX,
        help="Maximum area a pose can have, computed from a bounding box around its keypoints.",
    ),
):
    """
    Filter poses by area.

    This command unmarks the identity of poses with large areas.
    """
    filter_large_poses(
        in_pose,
        max_area,
    )
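For reference, a minimal sketch of how the new command could be exercised with Typer's test runner. The pose filename below is a placeholder, and the command/option spellings assume Typer's default underscore-to-hyphen conversion of `filter_large_area_pose` and `max_area`.

```python
# Hypothetical invocation sketch; "example_pose_est_v6.h5" is a made-up file name.
from typer.testing import CliRunner

from mouse_tracking.cli.utils import app

runner = CliRunner()
result = runner.invoke(
    app,
    ["filter-large-area-pose", "example_pose_est_v6.h5", "--max-area", "22500"],
)
print(result.exit_code)
```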
2 changes: 1 addition & 1 deletion src/mouse_tracking/core/config/pose_utils.py
@@ -36,4 +36,4 @@ class PoseUtilsConfig(BaseSettings):
    MIN_JABS_KEYPOINTS: int = 3

    # Large animals are rarely larger than 100px in our OFA
    OFA_MAX_EXPECTED_AREA_PX = 150 * 150
    OFA_MAX_EXPECTED_AREA_PX: int = 150 * 150
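The added annotation matters because pydantic settings classes only treat annotated attributes as fields. A standalone sketch, assuming `BaseSettings` comes from pydantic/pydantic-settings as the hunk context suggests; the class name here is illustrative, not the project's actual config module:

```python
# Minimal sketch. Depending on the pydantic version, an unannotated `NAME = value`
# attribute is either rejected at class definition (v2) or silently kept as a plain
# class attribute rather than a field (v1); the annotation makes it a real field.
from pydantic_settings import BaseSettings  # assumed dependency


class ExampleConfig(BaseSettings):
    OFA_MAX_EXPECTED_AREA_PX: int = 150 * 150  # annotated -> configurable field


config = ExampleConfig()
print(config.OFA_MAX_EXPECTED_AREA_PX)                            # 22500
print("OFA_MAX_EXPECTED_AREA_PX" in ExampleConfig.model_fields)   # True
```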
56 changes: 0 additions & 56 deletions src/mouse_tracking/pose/convert.py
@@ -1,14 +1,8 @@
"""Pose data conversion utilities."""

import os
import re

import h5py
import numpy as np

from mouse_tracking.core.exceptions import InvalidPoseFileException
from mouse_tracking.utils.run_length_encode import run_length_encode
from mouse_tracking.utils.writers import write_pixel_per_cm_attr, write_pose_v2_data


def v2_to_v3(pose_data, conf_data, threshold: float = 0.3):
@@ -95,53 +89,3 @@ def multi_to_v2(pose_data, conf_data, identity_data)
        return_list.append((cur_id, single_pose, single_conf))

    return return_list


def downgrade_pose_file(pose_h5_path, disable_id: bool = False):
    """Downgrades a multi-mouse pose file into multiple single mouse pose files.

    Args:
        pose_h5_path: input pose file
        disable_id: bool to disable identity embedding tracks (if available) and use tracklet data instead
    """
    if not os.path.isfile(pose_h5_path):
        raise FileNotFoundError(f"ERROR: missing file: {pose_h5_path}")
    # Read in all the necessary data
    with h5py.File(pose_h5_path, "r") as pose_h5:
        if "version" in pose_h5["poseest"].attrs:
            major_version = pose_h5["poseest"].attrs["version"][0]
        else:
            raise InvalidPoseFileException(
                f"Pose file {pose_h5_path} did not have a valid version."
            )
        if major_version == 2:
            print(f"Pose file {pose_h5_path} is already v2. Exiting.")
            exit(0)

        all_points = pose_h5["poseest/points"][:]
        all_confidence = pose_h5["poseest/confidence"][:]
        if major_version >= 4 and not disable_id:
            all_track_id = pose_h5["poseest/instance_embed_id"][:]
        elif major_version >= 3:
            all_track_id = pose_h5["poseest/instance_track_id"][:]
        try:
            config_str = pose_h5["poseest/points"].attrs["config"]
            model_str = pose_h5["poseest/points"].attrs["model"]
        except (KeyError, AttributeError):
            config_str = "unknown"
            model_str = "unknown"
        pose_attrs = pose_h5["poseest"].attrs
        if "cm_per_pixel" in pose_attrs and "cm_per_pixel_source" in pose_attrs:
            pixel_scaling = True
            px_per_cm = pose_h5["poseest"].attrs["cm_per_pixel"]
            source = pose_h5["poseest"].attrs["cm_per_pixel_source"]
        else:
            pixel_scaling = False

    downgraded_pose_data = multi_to_v2(all_points, all_confidence, all_track_id)
    new_file_base = re.sub("_pose_est_v[0-9]+\\.h5", "", pose_h5_path)
    for animal_id, pose_data, conf_data in downgraded_pose_data:
        out_fname = f"{new_file_base}_animal_{animal_id}_pose_est_v2.h5"
        write_pose_v2_data(out_fname, pose_data, conf_data, config_str, model_str)
        if pixel_scaling:
            write_pixel_per_cm_attr(out_fname, px_per_cm, source)
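`downgrade_pose_file` is removed here and re-added verbatim in `utils/writers.py` below. As a quick illustration of its per-animal output naming, with a hypothetical input filename:

```python
import re

# Toy illustration only; "video1_pose_est_v6.h5" is a made-up input name.
pose_h5_path = "video1_pose_est_v6.h5"
new_file_base = re.sub("_pose_est_v[0-9]+\\.h5", "", pose_h5_path)
for animal_id in (1, 2):
    print(f"{new_file_base}_animal_{animal_id}_pose_est_v2.h5")
# video1_animal_1_pose_est_v2.h5
# video1_animal_2_pose_est_v2.h5
```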
94 changes: 93 additions & 1 deletion src/mouse_tracking/utils/writers.py
@@ -1,13 +1,16 @@
"""Functions related to saving data to pose files."""

import os
import re
from pathlib import Path

import h5py
import numpy as np

from mouse_tracking.core.exceptions import InvalidPoseFileException
from mouse_tracking.matching import hungarian_match_points_seg
from mouse_tracking.pose.convert import v2_to_v3
from mouse_tracking.pose.convert import multi_to_v2, v2_to_v3
from mouse_tracking.pose.inspect import get_pose_bounding_box


def promote_pose_data(pose_file, current_version: int, new_version: int):
@@ -588,3 +591,92 @@ def write_pose_clip(
        for key, attrs in all_attrs.items():
            for cur_attr, data in attrs.items():
                out_f[key].attrs.create(cur_attr, data)


def downgrade_pose_file(pose_h5_path, disable_id: bool = False):
    """Downgrades a multi-mouse pose file into multiple single mouse pose files.

    Args:
        pose_h5_path: input pose file
        disable_id: bool to disable identity embedding tracks (if available) and use tracklet data instead
    """
    if not os.path.isfile(pose_h5_path):
        raise FileNotFoundError(f"ERROR: missing file: {pose_h5_path}")
    # Read in all the necessary data
    with h5py.File(pose_h5_path, "r") as pose_h5:
        if "version" in pose_h5["poseest"].attrs:
            major_version = pose_h5["poseest"].attrs["version"][0]
        else:
            raise InvalidPoseFileException(
                f"Pose file {pose_h5_path} did not have a valid version."
            )
        if major_version == 2:
            print(f"Pose file {pose_h5_path} is already v2. Exiting.")
            exit(0)

        all_points = pose_h5["poseest/points"][:]
        all_confidence = pose_h5["poseest/confidence"][:]
        if major_version >= 4 and not disable_id:
            all_track_id = pose_h5["poseest/instance_embed_id"][:]
        elif major_version >= 3:
            all_track_id = pose_h5["poseest/instance_track_id"][:]
        try:
            config_str = pose_h5["poseest/points"].attrs["config"]
            model_str = pose_h5["poseest/points"].attrs["model"]
        except (KeyError, AttributeError):
            config_str = "unknown"
            model_str = "unknown"
        pose_attrs = pose_h5["poseest"].attrs
        if "cm_per_pixel" in pose_attrs and "cm_per_pixel_source" in pose_attrs:
            pixel_scaling = True
            px_per_cm = pose_h5["poseest"].attrs["cm_per_pixel"]
            source = pose_h5["poseest"].attrs["cm_per_pixel_source"]
        else:
            pixel_scaling = False

    downgraded_pose_data = multi_to_v2(all_points, all_confidence, all_track_id)
    new_file_base = re.sub("_pose_est_v[0-9]+\\.h5", "", pose_h5_path)
    for animal_id, pose_data, conf_data in downgraded_pose_data:
        out_fname = f"{new_file_base}_animal_{animal_id}_pose_est_v2.h5"
        write_pose_v2_data(out_fname, pose_data, conf_data, config_str, model_str)
        if pixel_scaling:
            write_pixel_per_cm_attr(out_fname, px_per_cm, source)


def filter_large_poses(in_pose_f: str | Path, area_threshold: float):
    """Unmarks identity of poses that exceed area threshold.

    Args:
        in_pose_f: Input pose filename
        area_threshold: maximum pose bounding box area allowed

    Raises:
        InvalidPoseFileException if the pose file version is not >= 4.
    """
    with h5py.File(in_pose_f, "r") as f:
        try:
            current_version = f["poseest"].attrs["version"][0]
        except (KeyError, AttributeError, IndexError):
            raise InvalidPoseFileException("Pose file does not have a version.")
        if current_version < 4:
            raise InvalidPoseFileException(
                f"Pose file {in_pose_f} is {current_version}. Filtering is only implemented for pose file versions >= 4."
            )

        pose_data = f["poseest/points"][:]
        pose_confidence = f["poseest/confidence"][:]
        identity_data = f["poseest/instance_embed_id"][:]
        pose_masks = f["poseest/id_mask"][:]

    pose_boxes = get_pose_bounding_box(pose_data, pose_confidence)
    pose_boxes = pose_boxes.astype(float)
    pose_box_size = pose_boxes[:, :, 1] - pose_boxes[:, :, 0]
    pose_box_area = pose_box_size[:, :, 0] * pose_box_size[:, :, 1]

    identities_to_unassign = np.where(pose_box_area > area_threshold)
    identity_data[identities_to_unassign] = 0
    pose_masks[identities_to_unassign] = 1

    with h5py.File(in_pose_f, "a") as f:
        f["poseest/instance_embed_id"][:] = identity_data
        f["poseest/id_mask"][:] = pose_masks
12 changes: 12 additions & 0 deletions uv.lock

Some generated files are not rendered by default.

6 changes: 6 additions & 0 deletions vm/rclone.def
@@ -0,0 +1,6 @@
from: rclone/rclone
bootstrap: docker

%post
apk add --no-cache bash