Merged
Changes from all commits
39 commits
d39ccfc
Initial Mult Eval Implementation
anisehsani Nov 30, 2021
b513ced
merged dev
anisehsani Nov 30, 2021
233c557
bug fixes
anisehsani Nov 30, 2021
37f243d
Merge branch 'dev' into mult_eval
anisehsani Nov 30, 2021
7c10389
Changing tests to include evaluators
anisehsani Dec 1, 2021
d64c635
Test fixes for Evaluators
anisehsani Dec 1, 2021
d8f976d
Format and bug fix
anisehsani Dec 1, 2021
63402a6
Comment and None assertions
anisehsani Dec 1, 2021
6952fdd
Bug - checkpoint testing fix
anisehsani Dec 1, 2021
5891b95
small fix
anisehsani Dec 1, 2021
25c4e4b
temp commit with changes
anisehsani Dec 8, 2021
1e5e2e1
Merge branch 'dev' into mult_eval
anisehsani Dec 14, 2021
4bf6ec6
Small compatibility fixes
anisehsani Dec 14, 2021
7fb20ae
compatibility fixes
anisehsani Jan 10, 2022
f8a88de
Merge branch 'dev' into mult_eval
anisehsani Jan 10, 2022
029c879
bug fixes
anisehsani Jan 10, 2022
fbc9be7
temp_fixes
anisehsani Jan 21, 2022
848fe7e
Merge branch 'dev' into mult_eval
anisehsani Jan 21, 2022
c49fbb2
bug_fixes
anisehsani Jan 21, 2022
f863a5e
merge bug fix
anisehsani Jan 21, 2022
eff2c00
Merge branch 'dev' into mult_eval
hanlint Jan 21, 2022
8f60718
bug_fixes
anisehsani Jan 24, 2022
09f9490
temp
anisehsani Jan 26, 2022
223a99e
temp
anisehsani Jan 26, 2022
bfd62ed
Merge branch 'temp_mult_eval' into mult_eval
anisehsani Jan 27, 2022
5ed1656
formatting
anisehsani Jan 27, 2022
2beffac
Merge branch 'dev' into mult_eval
anisehsani Jan 28, 2022
7e605ae
Evaluator cleanup
ravi-mosaicml Jan 31, 2022
f2b65b7
linting errors
anisehsani Jan 31, 2022
e085b54
evaluator Cleanup
anisehsani Jan 31, 2022
2ba4a8c
Merge branch 'mult_eval' of https://github.com/anisehsani/composer in…
anisehsani Jan 31, 2022
eda93e9
linting/type fixes
anisehsani Feb 1, 2022
c6bdf81
Merge branch 'dev' into mult_eval
anisehsani Feb 1, 2022
e6c8da6
bug fix
anisehsani Feb 1, 2022
92c9ed8
bug_fix/eval_batch_size
anisehsani Feb 1, 2022
aad1144
docs/redundant checks
anisehsani Feb 1, 2022
7daafa9
Merge branch 'dev' into mult_eval
hanlint Feb 1, 2022
e1f8a39
removed eval_batch_size and dataspec dict
anisehsani Feb 1, 2022
c1f4e01
Merge branch 'mult_eval' of https://github.com/anisehsani/composer in…
anisehsani Feb 1, 2022
38 changes: 38 additions & 0 deletions composer/core/evaluator.py
@@ -0,0 +1,38 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Union

from torchmetrics import Metric, MetricCollection

from composer.core.data_spec import DataSpec as DataSpec

if TYPE_CHECKING:
    from composer.core.types import DataLoader, Metrics


class Evaluator:
    """Wrapper for a dataloader to include metrics that apply to a specific
    dataset.

    Attributes:
        label (str): Name of the Evaluator
        dataloader (Union[DataSpec, DataLoader]): Dataloader/DataSpec for evaluation data
        metrics (Metrics): Metrics to log. The metrics will be deep-copied to ensure that
            each evaluator updates only its metrics.
    """

    def __init__(self, *, label: str, dataloader: Union[DataSpec, DataLoader], metrics: Metrics):
        self.label = label
        if isinstance(dataloader, DataSpec):
            self.dataloader = dataloader
        else:
            self.dataloader = DataSpec(dataloader)

        # Forcing metrics to be a MetricCollection simplifies logging results
        metrics = copy.deepcopy(metrics)
        if isinstance(metrics, Metric):
            self.metrics = MetricCollection([metrics])
        else:
            self.metrics = metrics
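For context, a minimal sketch of how the new `Evaluator` wrapper might be constructed. The synthetic dataset and the `Accuracy` metric here are illustrative placeholders, not part of this diff; exact `Accuracy` constructor arguments depend on the installed torchmetrics version.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy, MetricCollection

from composer.core.evaluator import Evaluator

# A tiny synthetic dataset stands in for a real validation set.
val_dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
val_dataloader = DataLoader(val_dataset, batch_size=16)

evaluator = Evaluator(
    label="val",                             # name used when logging this evaluator's metrics
    dataloader=val_dataloader,               # a plain DataLoader is wrapped in a DataSpec internally
    metrics=MetricCollection([Accuracy()]),  # deep-copied so each evaluator owns its metric state
)
```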
10 changes: 5 additions & 5 deletions composer/core/state.py
@@ -59,7 +59,7 @@
"batch_num_tokens",
"outputs",
"train_dataloader",
"eval_dataloader",
"evaluators",
"_steps_per_epoch",
"_precision_context",
"profiler",
@@ -79,8 +79,8 @@ class State(Serializable):
grad_accum (int): The number of gradient accumulation steps to use. The size of each microbatch is ``train_batch_size / num_gpus / grad_accum``.
train_dataloader (types.DataLoader, types.DataSpec, or dict):
The :class:`types.DataLoader`, :class:`types.DataSpec`, or dict of :class:`types.DataSpec` kwargs to use for training.
eval_dataloader (types.DataLoader, types.DataSpec, or dict):
The :class:`types.DataLoader`, :class:`types.DataSpec`, or dict of :class:`types.DataSpec` kwargs to use for evaluation.
evaluators (Evaluators):
The :class:`types.Evaluators`, which hold the evaluation datasets and the metrics to compute on each of them.
max_duration (str or Time): The maximum duration to train for.

precision (str | Precision): The numerical precision to use for training. Should be one of ``[fp32, amp]``.
@@ -122,7 +122,7 @@ def __init__(
# data configurations
grad_accum: int,
train_dataloader: types.DataLoader,
eval_dataloader: types.DataLoader,
evaluators: types.Evaluators,

# stopping conditions
max_duration: Union[str, Time[int]],
@@ -148,7 +148,7 @@
self.model = model
self.grad_accum = grad_accum
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.evaluators = list(ensure_tuple(evaluators))
self.max_duration = max_duration
self.steps_per_epoch = steps_per_epoch

3 changes: 2 additions & 1 deletion composer/core/types.py
100644 → 100755
@@ -16,6 +16,7 @@

from composer.core.algorithm import Algorithm as Algorithm
from composer.core.data_spec import DataSpec as DataSpec
from composer.core.evaluator import Evaluator as Evaluator
from composer.core.event import Event as Event
from composer.core.logging import Logger as Logger
from composer.core.precision import Precision as Precision
@@ -141,8 +142,8 @@ def __len__(self) -> int:
...


Evaluators = Union[Evaluator, List[Evaluator], Tuple[Evaluator, ...]]
Metrics = Union[Metric, MetricCollection]

Optimizer = torch.optim.Optimizer
Optimizers = Union[Optimizer, Tuple[Optimizer, ...], List[Optimizer]]
Scheduler = torch.optim.lr_scheduler._LRScheduler
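The new `Evaluators` alias lets callers pass either a single `Evaluator` or a sequence of them; as the `state.py` hunk above shows, `State` normalizes whatever it receives with `list(ensure_tuple(...))`. A small sketch of that shape, assuming `ensure_tuple` is the usual single-item-or-sequence helper re-exported from `composer.utils`:

```python
from typing import List

from composer.core.types import Evaluator, Evaluators
from composer.utils import ensure_tuple  # assumed re-export path for ensure_tuple


def as_evaluator_list(evaluators: Evaluators) -> List[Evaluator]:
    # Mirrors ``self.evaluators = list(ensure_tuple(evaluators))`` in State.__init__:
    # a single Evaluator becomes a one-element list; lists and tuples become lists.
    return list(ensure_tuple(evaluators))
```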
2 changes: 2 additions & 0 deletions composer/datasets/__init__.py
100644 → 100755
@@ -5,6 +5,8 @@
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
from composer.datasets.dataset_registry import get_dataset_registry as get_dataset_registry
from composer.datasets.evaluator import EvaluatorHparams as EvaluatorHparams
from composer.datasets.glue import GLUEHparams as GLUEHparams
from composer.datasets.hparams import DatasetHparams as DatasetHparams
from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
23 changes: 23 additions & 0 deletions composer/datasets/dataset_registry.py
@@ -0,0 +1,23 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.datasets.ade20k import ADE20kDatasetHparams
from composer.datasets.brats import BratsDatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams
from composer.datasets.glue import GLUEHparams
from composer.datasets.imagenet import ImagenetDatasetHparams
from composer.datasets.lm_datasets import LMDatasetHparams
from composer.datasets.mnist import MNISTDatasetHparams

registry = {
    "ade20k": ADE20kDatasetHparams,
    "brats": BratsDatasetHparams,
    "imagenet": ImagenetDatasetHparams,
    "cifar10": CIFAR10DatasetHparams,
    "mnist": MNISTDatasetHparams,
    "lm": LMDatasetHparams,
    "glue": GLUEHparams
}


def get_dataset_registry():
    return registry
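This registry is what lets `EvaluatorHparams` (below) resolve its `eval_dataset` field by name. A quick illustration of a lookup, using only what the diff defines:

```python
from composer.datasets.dataset_registry import get_dataset_registry

registry = get_dataset_registry()
print(sorted(registry))               # ['ade20k', 'brats', 'cifar10', 'glue', 'imagenet', 'lm', 'mnist']
print(registry["cifar10"].__name__)   # 'CIFAR10DatasetHparams'
```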
85 changes: 85 additions & 0 deletions composer/datasets/evaluator.py
@@ -0,0 +1,85 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from __future__ import annotations

import copy
import logging
import textwrap
from dataclasses import dataclass
from typing import List, Optional

import yahp as hp
from torchmetrics import Metric, MetricCollection

from composer.core.types import Evaluator
from composer.datasets import DataloaderHparams
from composer.datasets.dataset_registry import get_dataset_registry
from composer.datasets.hparams import DatasetHparams
from composer.models.base import BaseMosaicModel

log = logging.getLogger(__name__)


@dataclass
class EvaluatorHparams(hp.Hparams):
    """Params for the :class:`Evaluator`.

    See the documentation for the :class:`Evaluator`.
    """
    hparams_registry = {  # type: ignore
        "eval_dataset": get_dataset_registry(),
    }

    label: str = hp.required(doc="Name of the Evaluator object. Used for logging/reporting metrics")
    eval_dataset: DatasetHparams = hp.required(doc="Evaluator dataset for the Evaluator")
    metric_names: Optional[List[str]] = hp.optional(
        doc=textwrap.dedent("""Name of the metrics for the evaluator. Can be a torchmetrics metric name or the
            class name of a metric returned by model.metrics(). If None (the default), uses all metrics in the model"""),
        default=None)

    def initialize_object(self, model: BaseMosaicModel, batch_size: int, dataloader_hparams: DataloaderHparams):
        """Initialize an :class:`Evaluator`.

        If ``metric_names`` is empty or None, the function returns
        a copy of all the model's default evaluation metrics.

        Args:
            model (BaseMosaicModel): The model, which is used to retrieve metric names.
            batch_size (int): The device batch size to use for the evaluation dataset.
            dataloader_hparams (DataloaderHparams): The hparams to use to construct a dataloader for the evaluation dataset.

        Returns:
            Evaluator: The evaluator.
        """
        dataloader = self.eval_dataset.initialize_object(batch_size=batch_size, dataloader_hparams=dataloader_hparams)

        # Get and copy all the model's associated evaluation metrics
        model_metrics = model.metrics(train=False)
        if isinstance(model_metrics, Metric):
            # Forcing metrics to be a MetricCollection simplifies logging results
            model_metrics = MetricCollection([model_metrics])

        # Use all the metrics from the model if no metric_names are specified
        if self.metric_names is None:
            evaluator_metrics = copy.deepcopy(model_metrics)
        else:
            evaluator_metrics = MetricCollection([])
            for metric_name in self.metric_names:
                try:
                    metric = model_metrics[metric_name]
                except KeyError as e:
                    raise RuntimeError(
                        textwrap.dedent(f"""No metric found with the name {metric_name}. Check if this
                        metric is compatible/listed in your model metrics.""")) from e
                assert isinstance(metric, Metric), "all values of a MetricCollection.__getitem__ should be a metric"
                evaluator_metrics.add_metrics(copy.deepcopy(metric))
            if len(evaluator_metrics) == 0:
                raise RuntimeError(
                    textwrap.dedent("""No metrics compatible with your model were added to this evaluator.
                    Check that the metrics you specified are compatible/listed in your model."""))

        return Evaluator(
            label=self.label,
            dataloader=dataloader,
            metrics=evaluator_metrics,
        )
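A hedged sketch of driving `initialize_object` from code. The names `my_dataset_hparams`, `my_model`, and `my_dataloader_hparams` are hypothetical placeholders (a `DatasetHparams`, a `BaseMosaicModel`, and a `DataloaderHparams` respectively), and the field values are illustrative, not taken from this PR:

```python
from composer.datasets import EvaluatorHparams

eval_hparams = EvaluatorHparams(
    label="val_top1",                 # how this evaluator's metrics are labeled in logs
    eval_dataset=my_dataset_hparams,  # any DatasetHparams, e.g. one resolved via the registry above
    metric_names=["Accuracy"],        # optional; omit (None) to copy all of model.metrics(train=False)
)

evaluator = eval_hparams.initialize_object(
    model=my_model,                           # supplies the candidate metrics
    batch_size=128,                           # per-device eval batch size
    dataloader_hparams=my_dataloader_hparams, # controls workers, pinning, etc.
)
```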
Empty file modified composer/datasets/hparams.py
100644 → 100755
Empty file.
5 changes: 4 additions & 1 deletion composer/loggers/logger_hparams.py
100644 → 100755
@@ -156,7 +156,10 @@ def get_flattened_dict(data: Dict[str, Any], _prefix: List[str] = []) -> Dict[st
if isinstance(item, dict):
found_sub_dicts = True
for sub_key, sub_val in item.items():
all_items.update(get_flattened_dict(sub_val, key_items + [sub_key]))
if isinstance(sub_val, dict):
all_items.update(get_flattened_dict(sub_val, key_items + [sub_key]))
else:
all_items.update({sub_key: sub_val})
if not found_sub_dicts:
all_items[key_name] = val
else:
13 changes: 10 additions & 3 deletions composer/loggers/tqdm_logger.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import collections.abc
import sys
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
@@ -108,9 +109,15 @@ def _start(self, state: State):
if dist.get_global_rank() != 0:
return
assert self.is_train is not None, "self.is_train should be set by the callback"
# TODO(anis) -- in #120, len(state.eval_dataloader) is inaccurate, as it does not incorporate
# trainer._eval_subset_num_batches. The evaluator spec should fix this.
total_steps = state.steps_per_epoch if self.is_train else len(state.eval_dataloader)
if self.is_train:
total_steps = state.steps_per_epoch
else:
total_steps = 0
for evaluator in state.evaluators:
dataloader_spec = evaluator.dataloader
assert isinstance(dataloader_spec.dataloader, collections.abc.Sized)
total_steps += len(dataloader_spec.dataloader)

desc = f'Epoch {int(state.timer.epoch)}'
position = 0 if self.is_train else 1
if not self.is_train:
4 changes: 3 additions & 1 deletion composer/profiler/dataloader_profiler.py
100644 → 100755
@@ -49,4 +49,6 @@ def init(self, state: State, logger: Logger):
textwrap.dedent("""To use the dataloader profiler, state.profiler must be set.
Make sure to run composer with the profiler -- i.e. with the `--profiler` CLI flag."""))
state.train_dataloader = ProfiledDataLoader(state.profiler, state.train_dataloader, "train")
state.eval_dataloader = ProfiledDataLoader(state.profiler, state.eval_dataloader, "eval")
for evaluator in state.evaluators:
evaluator.dataloader.dataloader = ProfiledDataLoader(state.profiler, evaluator.dataloader.dataloader,
evaluator.label)
Empty file modified composer/profiler/torch_profiler.py
100644 → 100755
Empty file.