docs/source/trainer/evaluation.rst
23 additions & 1 deletion
@@ -17,6 +17,26 @@ specified by the :class:`.Trainer` parameter ``eval_interval``.
    )

The metrics should be provided by :meth:`.ComposerModel.metrics`.
For more information, see the "Metrics" section in :doc:`/composer_model`.

To provide a deeper intuition, here's pseudocode for the evaluation logic that occurs every ``eval_interval``:

.. code:: python

    metrics = model.metrics(train=False)

    for batch in eval_dataloader:
        outputs, targets = model.validate(batch)
        metrics.update(outputs, targets)  # implements the torchmetrics interface

    metrics.compute()

- The trainer iterates over ``eval_dataloader`` and passes each batch to the model's :meth:`.ComposerModel.validate` method.
- Outputs of ``model.validate`` are used to update ``metrics`` (a :class:`torchmetrics.Metric` or :class:`torchmetrics.MetricCollection` returned by :meth:`model.metrics(train=False) <.ComposerModel.metrics>`).
- Finally, metrics over the whole validation dataset are computed.

Note that the tuple returned by :meth:`.ComposerModel.validate` provides the positional arguments to ``metrics.update``.
Please keep this in mind when using custom models and/or metrics.
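
As a concrete illustration of this contract, here is a minimal sketch of a classifier whose ``validate`` returns exactly the tuple its metric expects. The class, the wrapped ``torch.nn.Module``, and the choice of :class:`torchmetrics.Accuracy` are assumptions made for illustration; the only interface points relied on are ``metrics(train=False)`` and ``validate(batch)`` as described above.

.. code:: python

    import torch
    import torch.nn.functional as F
    from torchmetrics import Accuracy

    from composer.models import ComposerModel


    class SimpleClassifier(ComposerModel):
        """Hypothetical model; only the pieces relevant to evaluation are shown."""

        def __init__(self, module: torch.nn.Module):
            super().__init__()
            self.module = module
            # Newer torchmetrics releases may require arguments such as
            # ``Accuracy(task="multiclass", num_classes=...)``.
            self.val_acc = Accuracy()

        def forward(self, batch):
            inputs, _ = batch
            return self.module(inputs)

        def loss(self, outputs, batch):
            _, targets = batch
            return F.cross_entropy(outputs, targets)

        def metrics(self, train: bool = False):
            # The object returned here is what the trainer calls ``update`` and
            # ``compute`` on during evaluation.
            return self.val_acc

        def validate(self, batch):
            inputs, targets = batch
            outputs = self.module(inputs)
            # This tuple is unpacked positionally into ``metrics.update(outputs, targets)``,
            # i.e. ``Accuracy.update(preds=outputs, target=targets)``.
            return outputs, targets

If you swap in a custom metric, make sure its ``update`` signature accepts, in order, the tuple your ``validate`` returns.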

Multiple Datasets
-----------------
@@ -25,7 +45,7 @@ If there are multiple validation datasets that may have different metrics,
use :class:`.Evaluator` to specify each pair of dataloader and metrics.
This class is just a container for a few attributes:

- ``label``: a user-specified name for the evaluator.
- ``dataloader``: PyTorch :class:`~torch.utils.data.DataLoader` or our :class:`.DataSpec`.
  See :doc:`DataLoaders</trainer/dataloaders>` for more details.
- ``metric_names``: list of names of metrics to track.
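
As a sketch of how these attributes fit together, one evaluator for the GLUE MRPC task referenced in the example below might be built as follows. The dataloader and metric name are placeholders, and the snippet assumes the attributes above are accepted as keyword arguments by the :class:`.Evaluator` constructor (imported here from ``composer.core``).

.. code:: python

    from composer.core import Evaluator

    # ``mrpc_dataloader`` is a placeholder for any PyTorch DataLoader (or DataSpec)
    # over the MRPC validation set.
    glue_mrpc_task = Evaluator(
        label="glue_mrpc",             # user-specified name for this evaluator
        dataloader=mrpc_dataloader,    # the validation data for this task
        metric_names=["Accuracy"],     # must name metrics the model actually provides
    )

A list of such evaluators can then be passed to the trainer's ``eval_dataloader`` argument, as in the example below.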
@@ -55,3 +75,5 @@ can be specified as in the following example:
        eval_dataloader=[glue_mrpc_task, glue_mnli_task],
        ...
    )

Note that ``metric_names`` must be a subset of the metrics provided by the model in :meth:`.ComposerModel.metrics`.
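
Concretely, and as an assumed illustration of what "names" means here: if ``metrics(train=False)`` returns a :class:`torchmetrics.MetricCollection`, its keys default to the metric class names, and ``metric_names`` is expected to draw from those keys.

.. code:: python

    from torchmetrics import Accuracy, MetricCollection

    # A collection as a model might return it from ``metrics(train=False)``.
    # (Newer torchmetrics releases may require ``Accuracy(task=..., num_classes=...)``.)
    collection = MetricCollection([Accuracy()])
    print(list(collection.keys()))  # ['Accuracy']

    # Valid for the model above ...
    metric_names = ["Accuracy"]
    assert set(metric_names) <= set(collection.keys())
    # ... whereas a name like "BLEUScore" would not match any provided metric.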