composer/loggers/mlflow_logger.py (12 changes: 7 additions, 5 deletions)
@@ -419,22 +419,24 @@ def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
         for k, v in metrics.items():
             if any(fnmatch.fnmatch(k, pattern) for pattern in self.ignore_metrics):
                 continue
+
+            v_float = float(v)
             if k in self._metrics_cache:
                 value, last_step = self._metrics_cache[k]
-                if value == v and step < last_step + self.log_duplicated_metric_every_n_steps:
+                if value == v_float and step < last_step + self.log_duplicated_metric_every_n_steps:
                     # Skip logging the metric if it has the same value as the last step and it's
                     # within the step window.
                     continue
                 else:
                     # Log the metric if it has a different value or it's outside the step window,
                     # and update the metrics cache.
-                    self._metrics_cache[k] = (v, step)
-                    metrics_to_log[self.rename(k)] = float(v)
+                    self._metrics_cache[k] = (v_float, step)
+                    metrics_to_log[self.rename(k)] = v_float
             else:
                 # Log the metric if it's the first time it's being logged, and update the metrics
                 # cache.
-                self._metrics_cache[k] = (v, step)
-                metrics_to_log[self.rename(k)] = float(v)
+                self._metrics_cache[k] = (v_float, step)
+                metrics_to_log[self.rename(k)] = v_float
 
         log_metrics(
             metrics=metrics_to_log,
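The hunk above hoists the conversion into a single v_float = float(v), so the cache stores plain Python floats rather than (possibly GPU-resident) tensor objects, and the duplicate check compares float against float. Below is a minimal standalone sketch of the same deduplication idea; the helper name dedup_metrics and its arguments are hypothetical illustrations, not part of Composer's API.

from typing import Any


def dedup_metrics(
    metrics: dict[str, Any],
    step: int,
    cache: dict[str, tuple[float, int]],
    window: int,
) -> dict[str, float]:
    # Return only the metrics worth logging at this step, mirroring the cache
    # logic above: skip a metric when its value is unchanged and the last
    # logged step is still within `window`.
    to_log: dict[str, float] = {}
    for k, v in metrics.items():
        v_float = float(v)  # works for Python numbers and 0-d torch tensors alike
        if k in cache:
            cached_value, last_step = cache[k]
            if cached_value == v_float and step < last_step + window:
                continue  # same value, still inside the window: skip it
        cache[k] = (v_float, step)
        to_log[k] = v_float
    return to_log

Casting once up front also means the cache never keeps a tensor (and any device memory behind it) alive, which is what the new test below verifies.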
tests/loggers/test_mlflow_logger.py (19 changes: 19 additions, 0 deletions)
@@ -10,6 +10,7 @@
 
 import numpy as np
 import pytest
+import torch
 import yaml
 from torch.utils.data import DataLoader
 
@@ -46,6 +47,24 @@ def _get_latest_mlflow_run(experiment_name, tracking_uri=None):
         raise ValueError(f'Experiment with name {experiment_name} is unexpectedly empty')
 
 
+@pytest.mark.gpu
+def test_metrics_cache_cpu(
+    tmp_path: Path,
+):
+    logger = MLFlowLogger(
+        tracking_uri=tmp_path / Path('my-test-mlflow-uri'),
+    )
+
+    metric_tensor = torch.tensor(1.0, device='cuda')
+    metrics_dict = {'test_metric': metric_tensor}
+
+    logger.log_metrics(metrics_dict)
+
+    for v in logger._metrics_cache.values():
+        for item in v:
+            assert not isinstance(item, torch.Tensor)
+
+
 def test_mlflow_init_unspecified(monkeypatch):
     """ Test that MLFlow experiment is set up correctly when no parameters are specified
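For reference, the assertion in test_metrics_cache_cpu comes down to the fact that float() on a 0-d tensor returns a host-side Python float, so nothing left in _metrics_cache is a tensor that could pin device memory. A tiny sketch of the same check, independent of the logger (run on CPU here; the test itself uses a CUDA tensor under the gpu marker):

import torch

# float() on a 0-d tensor returns a plain Python float, detached from any device.
t = torch.tensor(1.0)  # on a GPU machine this could be device='cuda'
v = float(t)
assert isinstance(v, float)
assert not isinstance(v, torch.Tensor)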