Multiple Evaluator Datasets #120 (Merged)
Commits (39)
d39ccfc Initial Mult Eval Implementation (anisehsani)
b513ced merged dev (anisehsani)
233c557 bug fixes (anisehsani)
37f243d Merge branch 'dev' into mult_eval (anisehsani)
7c10389 Changing tests to include evaluators (anisehsani)
d64c635 Test fixes for Evaluators (anisehsani)
d8f976d Format and bug fix (anisehsani)
63402a6 Comment and None assertions (anisehsani)
6952fdd Bug - checkpoint testing fix (anisehsani)
5891b95 small fix (anisehsani)
25c4e4b temp commit with changes (anisehsani)
1e5e2e1 Merge branch 'dev' into mult_eval (anisehsani)
4bf6ec6 Small compatibility fixes (anisehsani)
7fb20ae compatibility fixes (anisehsani)
f8a88de Merge branch 'dev' into mult_eval (anisehsani)
029c879 bug fixes (anisehsani)
fbc9be7 temp_fixes (anisehsani)
848fe7e Merge branch 'dev' into mult_eval (anisehsani)
c49fbb2 bug_fixes (anisehsani)
f863a5e merge bug fix (anisehsani)
eff2c00 Merge branch 'dev' into mult_eval (hanlint)
8f60718 bug_fixes (anisehsani)
09f9490 temp (anisehsani)
223a99e temp (anisehsani)
bfd62ed Merge branch 'temp_mult_eval' into mult_eval (anisehsani)
5ed1656 formatting (anisehsani)
2beffac Merge branch 'dev' into mult_eval (anisehsani)
7e605ae Evaluator cleanup (ravi-mosaicml)
f2b65b7 linting errors (anisehsani)
e085b54 evaluator Cleanup (anisehsani)
2ba4a8c Merge branch 'mult_eval' of https://github.com/anisehsani/composer in… (anisehsani)
eda93e9 linting/type fixes (anisehsani)
c6bdf81 Merge branch 'dev' into mult_eval (anisehsani)
e6c8da6 bug fix (anisehsani)
92c9ed8 bug_fix/eval_batch_size (anisehsani)
aad1144 docs/redundant checks (anisehsani)
7daafa9 Merge branch 'dev' into mult_eval (hanlint)
e1f8a39 removed eval_batch_size and dataspec dict (anisehsani)
c1f4e01 Merge branch 'mult_eval' of https://github.com/anisehsani/composer in… (anisehsani)
@@ -0,0 +1,38 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Union

from torchmetrics import Metric, MetricCollection

from composer.core.data_spec import DataSpec as DataSpec

if TYPE_CHECKING:
    from composer.core.types import DataLoader, Metrics


class Evaluator:
    """Wrapper for a dataloader to include metrics that apply to a specific
    dataset.

    Attributes:
        label (str): Name of the Evaluator
        dataloader (Union[DataSpec, DataLoader]): Dataloader/DataSpec for evaluation data
        metrics (Metrics): Metrics to log. The metrics will be deep-copied to ensure that
            each evaluator updates only its metrics.
    """

    def __init__(self, *, label: str, dataloader: Union[DataSpec, DataLoader], metrics: Metrics):
        self.label = label
        if isinstance(dataloader, DataSpec):
            self.dataloader = dataloader
        else:
            self.dataloader = DataSpec(dataloader)

        # Forcing metrics to be a MetricCollection simplifies logging results
        metrics = copy.deepcopy(metrics)
        if isinstance(metrics, Metric):
            self.metrics = MetricCollection([metrics])
        else:
            self.metrics = metrics
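The deep copy in `Evaluator.__init__` is what lets several evaluators share one set of model metrics without cross-contaminating each other's accumulated state. A minimal stand-alone sketch of that behavior, using plain Python instead of torchmetrics (the `ToyMetric` and `ToyEvaluator` names are hypothetical, not part of composer):

```python
import copy


class ToyMetric:
    """Minimal stand-in for a torchmetrics.Metric: accumulates correct/total counts."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


class ToyEvaluator:
    """Mirrors the deep-copy behavior of the Evaluator above."""

    def __init__(self, *, label, metrics):
        self.label = label
        # Deep copy so each evaluator accumulates independently of its siblings
        self.metrics = copy.deepcopy(metrics)


shared = ToyMetric()
eval_a = ToyEvaluator(label="val_a", metrics=shared)
eval_b = ToyEvaluator(label="val_b", metrics=shared)

eval_a.metrics.update([1, 0], [1, 1])  # 1 of 2 correct
eval_b.metrics.update([1, 1], [1, 1])  # 2 of 2 correct

print(eval_a.metrics.compute())  # 0.5, unaffected by eval_b's updates
print(eval_b.metrics.compute())  # 1.0
```

Without the deep copy, both evaluators would update the same metric object and the per-dataset results would be meaningless.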
@@ -0,0 +1,23 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from composer.datasets.ade20k import ADE20kDatasetHparams
from composer.datasets.brats import BratsDatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams
from composer.datasets.glue import GLUEHparams
from composer.datasets.imagenet import ImagenetDatasetHparams
from composer.datasets.lm_datasets import LMDatasetHparams
from composer.datasets.mnist import MNISTDatasetHparams

registry = {
    "ade20k": ADE20kDatasetHparams,
    "brats": BratsDatasetHparams,
    "imagenet": ImagenetDatasetHparams,
    "cifar10": CIFAR10DatasetHparams,
    "mnist": MNISTDatasetHparams,
    "lm": LMDatasetHparams,
    "glue": GLUEHparams
}


def get_dataset_registry():
    return registry
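The registry is a plain dict mapping YAML-friendly names to `DatasetHparams` subclasses; wiring it into `hparams_registry` is what lets yahp pick the right subclass from a config key. A hedged, self-contained sketch of the lookup pattern (the stub classes and the `resolve` helper are illustrative only, not composer APIs):

```python
# Stub classes standing in for DatasetHparams subclasses
class MNISTDatasetHparams:
    pass


class CIFAR10DatasetHparams:
    pass


# Name -> hparams-class mapping, mirroring the registry above
registry = {
    "mnist": MNISTDatasetHparams,
    "cifar10": CIFAR10DatasetHparams,
}


def resolve(name):
    """Look up a dataset hparams class by its registry key, failing loudly."""
    try:
        return registry[name]
    except KeyError:
        raise ValueError(f"Unknown dataset '{name}'. Choices: {sorted(registry)}")


print(resolve("mnist").__name__)  # MNISTDatasetHparams
```

Failing with the list of valid choices keeps config typos easy to diagnose.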
@@ -0,0 +1,91 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from __future__ import annotations

import copy
import logging
import textwrap
from dataclasses import dataclass
from typing import List, Optional

import yahp as hp
from torchmetrics import Metric, MetricCollection

from composer.core.types import Evaluator
from composer.datasets import DataloaderHparams
from composer.datasets.dataset_registry import get_dataset_registry
from composer.datasets.hparams import DatasetHparams
from composer.models.base import BaseMosaicModel

log = logging.getLogger(__name__)


@dataclass
class EvaluatorHparams(hp.Hparams):
    """Params for the :class:`Evaluator`.

    See the documentation for the :class:`Evaluator`.
    """
    hparams_registry = {  # type: ignore
        "eval_dataset": get_dataset_registry(),
    }

    label: str = hp.required(doc="Name of the Evaluator object. Used for logging/reporting metrics")
    eval_dataset: DatasetHparams = hp.required(doc="Evaluator dataset for the Evaluator")
    metric_names: Optional[List[str]] = hp.optional(
        doc=textwrap.dedent("""Name of the metrics for the evaluator. Can be a torchmetrics metric name or the
            class name of a metric returned by model.metrics(). If None (the default), uses all metrics in the model"""),
        default=None)
    eval_batch_size: Optional[int] = hp.optional(
        doc="batch size to use for each evaluation step",
        default=None,
    )

    def initialize_object(self, model: BaseMosaicModel, batch_size: int, dataloader_hparams: DataloaderHparams):
        """Initialize an :class:`Evaluator`.

        If ``metric_names`` is empty or ``None``, the evaluator receives a copy
        of all of the model's default evaluation metrics.

        Args:
            model (BaseMosaicModel): The model, which is used to retrieve metric names
            batch_size (int): The device batch size to use for the evaluation dataset
            dataloader_hparams (DataloaderHparams): The hparams to use to construct a dataloader for the evaluation dataset

        Returns:
            Evaluator: The evaluator
        """
        evaluator_batch_size = self.eval_batch_size if self.eval_batch_size is not None else batch_size
        dataloader = self.eval_dataset.initialize_object(batch_size=evaluator_batch_size,
                                                         dataloader_hparams=dataloader_hparams)

        # Get and copy all the model's associated evaluation metrics
        model_metrics = model.metrics(train=False)
        if isinstance(model_metrics, Metric):
            # Forcing metrics to be a MetricCollection simplifies logging results
            model_metrics = MetricCollection([model_metrics])

        # Use all the metrics from the model if no metric_names are specified
        if self.metric_names is None:
            evaluator_metrics = copy.deepcopy(model_metrics)
        else:
            evaluator_metrics = MetricCollection([])
            for metric_name in self.metric_names:
                try:
                    metric = model_metrics[metric_name]
                except KeyError as e:
                    raise RuntimeError(
                        textwrap.dedent(f"""No metric found with the name {metric_name}. Check if this
                            metric is compatible/listed in your model metrics.""")) from e
                assert isinstance(metric, Metric), "all values of a MetricCollection.__getitem__ should be a metric"
                evaluator_metrics.add_metrics(copy.deepcopy(metric))
            if len(evaluator_metrics) == 0:
                raise RuntimeError(
                    textwrap.dedent("""No metrics compatible with your model were added to this evaluator.
                        Check that the metrics you specified are compatible/listed in your model."""))

        return Evaluator(
            label=self.label,
            dataloader=dataloader,
            metrics=evaluator_metrics,
        )
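The metric-selection logic in `initialize_object` can be summarized independently of torchmetrics: copy everything when `metric_names` is `None`, otherwise pick metrics by name and fail loudly on unknown names or an empty selection. A minimal sketch using a plain dict as a stand-in for `MetricCollection` (the `select_metrics` helper is hypothetical, not a composer API):

```python
import copy

# Stand-in for model.metrics(train=False): class name -> metric object
model_metrics = {"Accuracy": "accuracy-metric", "CrossEntropyLoss": "ce-metric"}


def select_metrics(metric_names, model_metrics):
    """Mirror the selection rules of EvaluatorHparams.initialize_object."""
    if metric_names is None:
        # No names given: copy all of the model's evaluation metrics
        return copy.copy(model_metrics)
    selected = {}
    for name in metric_names:
        try:
            selected[name] = model_metrics[name]
        except KeyError as e:
            raise RuntimeError(f"No metric found with the name {name}.") from e
    if not selected:
        raise RuntimeError("No metrics compatible with your model were added to this evaluator.")
    return selected


print(sorted(select_metrics(None, model_metrics)))         # both metrics
print(sorted(select_metrics(["Accuracy"], model_metrics)))  # just Accuracy
```

An empty `metric_names` list raises rather than silently producing an evaluator that logs nothing, matching the guard in the real code.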