Commit c6f2c93

Complete upgrade of torchmetrics accuracy (#2025)
Replace all instances of Accuracy with MulticlassAccuracy

---------

Co-authored-by: nik-mosaic <None>
Co-authored-by: Mihir Patel <[email protected]>
1 parent 28bf919 commit c6f2c93
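
Background for the rename: torchmetrics 0.11 split the task-generic Accuracy metric into task-specific classes (the generic wrapper now requires a task argument), so multiclass code constructs MulticlassAccuracy directly. A minimal before/after sketch, assuming torchmetrics >= 0.11 and an illustrative 10-class problem:

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    # Before (torchmetrics < 0.11): metric = torchmetrics.Accuracy()
    # After: the multiclass task has a dedicated class, which requires num_classes.
    metric = MulticlassAccuracy(num_classes=10)  # 10 classes chosen for illustration

    preds = torch.randn(8, 10)           # a batch of logits, shape (batch, num_classes)
    target = torch.randint(0, 10, (8,))  # integer class labels
    print(metric(preds, target))         # batch accuracy as a 0-dim tensor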

File tree

14 files changed: +41 −41 lines changed


composer/callbacks/early_stopper.py

Lines changed: 2 additions & 2 deletions
@@ -27,11 +27,11 @@ class EarlyStopper(Callback):
     >>> from composer import Evaluator, Trainer
     >>> from composer.callbacks.early_stopper import EarlyStopper
     >>> # constructing trainer object with this callback
-    >>> early_stopper = EarlyStopper("Accuracy", "my_evaluator", patience=1)
+    >>> early_stopper = EarlyStopper('MulticlassAccuracy', 'my_evaluator', patience=1)
     >>> evaluator = Evaluator(
     ...     dataloader = eval_dataloader,
     ...     label = 'my_evaluator',
-    ...     metric_names = ['Accuracy']
+    ...     metric_names = ['MulticlassAccuracy']
     ... )
     >>> trainer = Trainer(
     ...     model=model,

composer/callbacks/mlperf.py

Lines changed: 3 additions & 3 deletions
@@ -82,7 +82,7 @@ class MLPerfCallback(Callback):
         callback = MLPerfCallback(
             root_folder='/submission',
             index=0,
-            metric_name='Accuracy',
+            metric_name='MulticlassAccuracy',
             metric_label='eval',
             target='0.759',
         )
@@ -113,7 +113,7 @@ class MLPerfCallback(Callback):
         division (str, optional): Division of submission. Currently only ``open`` division supported.
             Default: ``'open'``.
         metric_name (str, optional): name of the metric to compare against the target.
-            Default: ``Accuracy``.
+            Default: ``MulticlassAccuracy``.
         metric_label (str, optional): The label name. The metric will be accessed via
             ``state.eval_metrics[metric_label][metric_name]``.
         submitter (str, optional): Submitting organization. Default: ``"MosaicML"``.
@@ -135,7 +135,7 @@ def __init__(
         benchmark: str = 'resnet',
         target: float = 0.759,
         division: str = 'open',
-        metric_name: str = 'Accuracy',
+        metric_name: str = 'MulticlassAccuracy',
         metric_label: str = 'eval',
         submitter: str = 'MosaicML',
         system_name: Optional[str] = None,
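
As the docstring notes, the monitored metric is reached with a two-level lookup on the trainer state. A sketch of that access pattern (assumes a populated State object mid-training; not the callback's full logic):

    # state.eval_metrics maps dataloader label -> {metric name -> Metric object}
    metric = state.eval_metrics['eval']['MulticlassAccuracy']
    if metric.compute() >= 0.759:  # compare against the MLPerf target
        ...  # record that the run has converged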

composer/callbacks/threshold_stopper.py

Lines changed: 3 additions & 3 deletions
@@ -17,14 +17,14 @@ class ThresholdStopper(Callback):
     Example:
         .. doctest::
 
+            >>> from composer import Evaluator, Trainer
             >>> from composer.callbacks.threshold_stopper import ThresholdStopper
-            >>> from torchmetrics.classification.accuracy import Accuracy
             >>> # constructing trainer object with this callback
-            >>> threshold_stopper = ThresholdStopper("Accuracy", "my_evaluator", 0.7)
+            >>> threshold_stopper = ThresholdStopper('MulticlassAccuracy', 'my_evaluator', 0.7)
            >>> evaluator = Evaluator(
            ...     dataloader = eval_dataloader,
            ...     label = 'my_evaluator',
-           ...     metric_names = ['Accuracy']
+           ...     metric_names = ['MulticlassAccuracy']
            ... )
            >>> trainer = Trainer(
            ...     model=model,

composer/core/evaluator.py

Lines changed: 4 additions & 4 deletions
@@ -99,24 +99,24 @@ class Evaluator:
     .. doctest::
 
         >>> eval_evaluator = Evaluator(
-        ...     label="myEvaluator",
+        ...     label='myEvaluator',
         ...     dataloader=eval_dataloader,
-        ...     metric_names=['Accuracy']
+        ...     metric_names=['MulticlassAccuracy']
         ... )
         >>> trainer = Trainer(
         ...     model=model,
         ...     train_dataloader=train_dataloader,
         ...     eval_dataloader=eval_evaluator,
         ...     optimizers=optimizer,
-        ...     max_duration="1ep",
+        ...     max_duration='1ep',
         ... )
 
     Args:
         label (str): Name of the Evaluator.
         dataloader (DataSpec | Iterable | Dict[str, Any]): Iterable that yields batches, a :class:`.DataSpec`
             for evaluation, or a Dict of :class:`.DataSpec` kwargs.
         metric_names: The list of metric names to compute.
-            Each value in this list can be a regex string (e.g. "Accuracy", "f1" for "BinaryF1Score",
+            Each value in this list can be a regex string (e.g. "MulticlassAccuracy", "f1" for "BinaryF1Score",
             "Top-." for "Top-1", "Top-2", etc). Each regex string will be matched against the keys of the dictionary returned
             by ``model.get_metrics()``. All matching metrics will be evaluated.
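
The regex matching this docstring describes can be pictured as follows. This is an illustrative sketch, not Composer's actual implementation; the metrics dict stands in for whatever ``model.get_metrics()`` returns, and the docstring's lowercase "f1" example suggests the real matching is case-insensitive:

    import re

    # Hypothetical output of model.get_metrics()
    metrics = {'MulticlassAccuracy': 1, 'CrossEntropy': 2, 'Top-1': 3, 'Top-2': 4}

    def select_metrics(patterns, metrics):
        # Keep every metric whose name matches any of the given regex patterns.
        return {name: m for name, m in metrics.items()
                if any(re.search(p, name) for p in patterns)}

    print(select_metrics(['Top-.'], metrics))  # {'Top-1': 3, 'Top-2': 4}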

composer/core/state.py

Lines changed: 2 additions & 2 deletions
@@ -271,8 +271,8 @@ class State(Serializable):
         ...     ...,
         ...     train_dataloader=train_dataloader,
         ...     eval_dataloader=[
-        ...         Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['Accuracy']),
-        ...         Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['Accuracy']),
+        ...         Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['MulticlassAccuracy']),
+        ...         Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['MulticlassAccuracy']),
         ...     ],
         ...     )
         >>> trainer.fit()

composer/models/bert/model.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def create_bert_classification(num_labels: int = 2,
     Second, the returned :class:`.ComposerModel`'s train/validation metrics will be :class:`~torchmetrics.MeanSquaredError` and :class:`~torchmetrics.SpearmanCorrCoef`.
 
     For the classification case (when ``num_labels > 1``), the training loss is :class:`~torch.nn.CrossEntropyLoss`, and the train/validation
-    metrics are :class:`~torchmetrics.Accuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
+    metrics are :class:`~torchmetrics.MulticlassAccuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
     """
     try:
         import transformers
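
The two metric regimes this docstring contrasts are selected purely by num_labels; a usage sketch (assuming the factory is importable from composer.models and transformers is installed):

    from composer.models import create_bert_classification

    # num_labels == 1 -> regression path (MSE loss; MeanSquaredError, SpearmanCorrCoef)
    # num_labels > 1  -> classification path (cross-entropy; MulticlassAccuracy, MatthewsCorrCoef)
    model = create_bert_classification(num_labels=2)  # num_labels == 2 also adds BinaryF1Score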

composer/models/classify_mnist/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,6 @@
 _dataset = 'MNIST'
 _name = 'SimpleConvNet'
 _quality = ''
-_metric = 'Accuracy'
+_metric = 'MulticlassAccuracy'
 _ttt = '?'
 _hparams = 'classify_mnist_cpu.yaml'

composer/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -2535,13 +2535,13 @@ def eval(
             glue_mrpc_task = Evaluator(
                 label='glue_mrpc',
                 dataloader=mrpc_dataloader,
-                metric_names=['BinaryF1Score', 'Accuracy']
+                metric_names=['BinaryF1Score', 'MulticlassAccuracy']
             )
 
             glue_mnli_task = Evaluator(
                 label='glue_mnli',
                 dataloader=mnli_dataloader,
-                metric_names=['Accuracy']
+                metric_names=['MulticlassAccuracy']
             )
 
             trainer = Trainer(

docs/source/composer_model.rst

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ A full example of a validation implementation would be:
 
         def get_metrics(self, is_train=False):
             # defines which metrics to use in each phase of training
-            return {'Accuracy': self.train_accuracy} if train else {'Accuracy': self.val_accuracy}
+            return {'MulticlassAccuracy': self.train_accuracy} if is_train else {'MulticlassAccuracy': self.val_accuracy}
 
 .. note::

docs/source/notes/early_stopping.rst

Lines changed: 9 additions & 9 deletions
@@ -16,7 +16,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
     from composer.callbacks.early_stopper import EarlyStopper
 
     early_stopper = EarlyStopper(
-        monitor='Accuracy',
+        monitor='MulticlassAccuracy',
         dataloader_label='train',
         patience='50ba',
         comp=torch.greater,
@@ -32,7 +32,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
         max_duration="1ep",
     )
 
-In the above example, the ``'train'`` label means the callback is tracking the ``Accuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.
+In the above example, the ``'train'`` label means the callback is tracking the ``MulticlassAccuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.
 
 We also set ``patience='50ba'`` and ``min_delta=0.01``, which means that every 50 batches, if the accuracy does not exceed the best recorded accuracy by ``0.01``, training is stopped. The ``comp`` argument indicates that 'better' here means higher accuracy. Note that the ``patience`` parameter can take either a time string (see :doc:`Time</trainer/time>`) or an integer specifying a number of epochs.
 
@@ -53,8 +53,8 @@ The :class:`.ThresholdStopper` callback also monitors a specific metric, but ha
     from composer.callbacks.threshold_stopper import ThresholdStopper
 
     threshold_stopper = ThresholdStopper(
-        monitor="Accuracy",
-        dataloader_label="eval",
+        monitor='MulticlassAccuracy',
+        dataloader_label='eval',
         threshold=0.8,
     )
 
@@ -64,7 +64,7 @@ The :class:`.ThresholdStopper` callback also monitors a specific metric, but ha
     eval_dataloader=eval_dataloader,
     optimizers=optimizer,
     callbacks=[threshold_stopper],
-    max_duration="1ep",
+    max_duration='1ep',
 )
 
 In this example, training will exit when the model's validation accuracy exceeds 0.8. For a full list of arguments, see the documentation for :class:`.ThresholdStopper`.
@@ -76,7 +76,7 @@ When there are multiple datasets and metrics to use for validation and evaluation
 
 Each Evaluator object is marked with a ``label`` field for logging, and a ``metric_names`` field that accepts a list of metric names. These can be provided to the callbacks above to indicate which metric to monitor.
 
-In the example below, the callback will monitor the `Accuracy` metric in the dataloader marked `eval_dataset1`.
+In the example below, the callback will monitor the `MulticlassAccuracy` metric in the dataloader marked `eval_dataset1`.
 
 .. testsetup::
 
@@ -90,17 +90,17 @@ In the example below, the callback will monitor the `Accuracy` metric in the dat
     evaluator1 = Evaluator(
         label='eval_dataset1',
         dataloader=eval_dataloader,
-        metric_names=['Accuracy']
+        metric_names=['MulticlassAccuracy']
     )
 
     evaluator2 = Evaluator(
         label='eval_dataset2',
         dataloader=eval_dataloader2,
-        metric_names=['Accuracy']
+        metric_names=['MulticlassAccuracy']
     )
 
     early_stopper = EarlyStopper(
-        monitor='Accuracy',
+        monitor='MulticlassAccuracy',
         dataloader_label='eval_dataset1',
         patience=1
     )
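
The hunk ends with the callback construction; wiring these pieces into a trainer follows the same pattern as the earlier examples in this doc (a sketch reusing the names above, not part of the diff):

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=[evaluator1, evaluator2],  # both evaluators run during eval
        optimizers=optimizer,
        callbacks=[early_stopper],  # monitors evaluator1's MulticlassAccuracy
        max_duration='1ep',
    )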
