24 changes: 23 additions & 1 deletion docs/source/trainer/evaluation.rst
@@ -17,6 +17,26 @@ specified by the :class:`.Trainer` parameter ``eval_interval``.
)

The metrics should be provided by :meth:`.ComposerModel.metrics`.
For more information, see the "Metrics" section in :doc:`/composer_model`.

To provide a deeper intuition, here's pseudocode for the evaluation logic that occurs every ``eval_interval``:

.. code:: python

    metrics = model.metrics(train=False)

    for batch in eval_dataloader:
        outputs, targets = model.validate(batch)
        metrics.update(outputs, targets)  # implements the torchmetrics interface

    metrics.compute()

- The trainer iterates over ``eval_dataloader`` and passes each batch to the model's :meth:`.ComposerModel.validate` method.
- Outputs of ``model.validate`` are used to update ``metrics`` (a :class:`torchmetrics.Metric` or :class:`torchmetrics.MetricCollection` returned by :meth:`model.metrics(train=False) <.ComposerModel.metrics>`).
- Finally, metrics over the whole validation dataset are computed.

Note that the tuple returned by :meth:`.ComposerModel.validate` provides the positional arguments to ``metrics.update``.
Please keep this in mind when using custom models and/or metrics.
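
The positional contract between ``validate`` and ``metrics.update`` can be illustrated with a minimal, dependency-free sketch. ``ToyAccuracy`` and ``ToyModel`` are hypothetical stand-ins (not Composer or torchmetrics classes); they only mimic the ``update``/``compute`` and ``validate``/``metrics`` shapes used in the pseudocode, and the data is made up:

```python
class ToyAccuracy:
    """Hypothetical stand-in that mimics the update/compute shape of a metric."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, outputs, targets):
        # Positional order matters: outputs first, then targets.
        self.correct += sum(int(o == t) for o, t in zip(outputs, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


class ToyModel:
    """Hypothetical stand-in for a custom model (not a real ComposerModel)."""

    def validate(self, batch):
        inputs, targets = batch
        outputs = [x % 2 for x in inputs]  # toy "prediction": parity of the input
        return outputs, targets

    def metrics(self, train=False):
        return ToyAccuracy()


model = ToyModel()
# Each batch is an (inputs, targets) pair; a plain list stands in for a dataloader.
eval_dataloader = [([1, 2, 3], [1, 0, 0]), ([4, 5], [0, 1])]

metrics = model.metrics(train=False)
for batch in eval_dataloader:
    # Unpacking the returned tuple makes the positional contract explicit.
    metrics.update(*model.validate(batch))

accuracy = metrics.compute()  # 4 correct out of 5 samples
```

If a custom ``validate`` returned ``(targets, outputs)`` instead, ``update`` would receive its arguments swapped, and a metric that is not symmetric in its arguments would compute the wrong value without raising an error.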

Multiple Datasets
-----------------
@@ -25,7 +45,7 @@ If there are multiple validation datasets that may have different metrics,
use :class:`.Evaluator` to specify each pair of dataloader and metrics.
This class is just a container for a few attributes:

- ``label``: a user-specified name for the metric.
- ``label``: a user-specified name for the evaluator.
- ``dataloader``: PyTorch :class:`~torch.utils.data.DataLoader` or our :class:`.DataSpec`.
See :doc:`DataLoaders</trainer/dataloaders>` for more details.
- ``metric_names``: list of names of metrics to track.
@@ -55,3 +75,5 @@ can be specified as in the following example:
eval_dataloader=[glue_mrpc_task, glue_mnli_task],
...
)

Note that ``metric_names`` must be a subset of the metrics provided by the model in :meth:`.ComposerModel.metrics`.
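
One way to catch a mismatch early is to compare the requested names against what the model provides before training starts. The sketch below is only an illustration: ``check_metric_names`` is a hypothetical helper, not part of Composer, and the metric names are made up for the example.

```python
def check_metric_names(metric_names, model_metric_names):
    """Hypothetical helper: return the requested names the model does not provide."""
    available = set(model_metric_names)
    return [name for name in metric_names if name not in available]


# Made-up names, standing in for whatever the model's metrics are called.
model_metric_names = ["Accuracy", "CrossEntropy"]

missing_ok = check_metric_names(["Accuracy"], model_metric_names)
missing_bad = check_metric_names(["Accuracy", "F1"], model_metric_names)
```

An empty result means every requested name is available; a non-empty result lists the names that would violate the subset requirement.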