You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: composer/models/bert/model.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -173,7 +173,7 @@ def create_bert_classification(num_labels: int = 2,
173
173
Second, the returned :class:`.ComposerModel`'s train/validation metrics will be :class:`~torchmetrics.MeanSquaredError` and :class:`~torchmetrics.SpearmanCorrCoef`.
174
174
175
175
For the classification case (when ``num_labels > 1``), the training loss is :class:`~torch.nn.CrossEntropyLoss`, and the train/validation
176
-
metrics are :class:`~torchmetrics.Accuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
176
+
metrics are :class:`~torchmetrics.MulticlassAccuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
Copy file name to clipboardExpand all lines: docs/source/notes/early_stopping.rst
+9-9Lines changed: 9 additions & 9 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -16,7 +16,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
16
16
from composer.callbacks.early_stopper import EarlyStopper
17
17
18
18
early_stopper = EarlyStopper(
19
-
monitor='Accuracy',
19
+
monitor='MulticlassAccuracy',
20
20
dataloader_label='train',
21
21
patience='50ba',
22
22
comp=torch.greater,
@@ -32,7 +32,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
32
32
max_duration="1ep",
33
33
)
34
34
35
-
In the above example, the ``'train'`` label means the callback is tracking the ``Accuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.
35
+
In the above example, the ``'train'`` label means the callback is tracking the ``MulticlassAccuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.
36
36
37
37
We also set ``patience='50ba'`` and ``min_delta=0.01`` which means that every 50 batches, if the Accuracy does not exceed the best recorded Accuracy by ``0.01``, training is stopped. The ``comp`` argument indicates that 'better' here means higher accuracy. Note that the ``patience`` parameter can take both a time string (see :doc:`Time</trainer/time>`) or an integer which specifies a number of epochs.
38
38
@@ -53,8 +53,8 @@ The :class:`.ThresholdStopper`` callback also monitors a specific metric, but ha
53
53
from composer.callbacks.threshold_stopper import ThresholdStopper
54
54
55
55
threshold_stopper = ThresholdStopper(
56
-
monitor="Accuracy",
57
-
dataloader_label="eval",
56
+
monitor='MulticlassAccuracy',
57
+
dataloader_label='eval',
58
58
threshold=0.8,
59
59
)
60
60
@@ -64,7 +64,7 @@ The :class:`.ThresholdStopper`` callback also monitors a specific metric, but ha
64
64
eval_dataloader=eval_dataloader,
65
65
optimizers=optimizer,
66
66
callbacks=[threshold_stopper],
67
-
max_duration="1ep",
67
+
max_duration='1ep',
68
68
)
69
69
70
70
In this example, training will exit when the model's validation accuracy exceeds 0.8. For a full list of arguments, see the documentation for :class:`.ThresholdStopper.`
@@ -76,7 +76,7 @@ When there are multiple datasets and metrics to use for validation and evaluatio
76
76
77
77
Each Evaluator object is marked with a ``label`` field for logging, and a ``metric_names`` field that accepts a list of metric names. These can be provided to the callbacks above to indiciate which metric to monitor.
78
78
79
-
In the example below, the callback will monitor the `Accuracy` metric in the dataloader marked `eval_dataset1`.`
79
+
In the example below, the callback will monitor the `MulticlassAccuracy` metric in the dataloader marked `eval_dataset1`.`
80
80
81
81
.. testsetup::
82
82
@@ -90,17 +90,17 @@ In the example below, the callback will monitor the `Accuracy` metric in the dat
0 commit comments