diff --git a/docs/source/trainer/evaluation.rst b/docs/source/trainer/evaluation.rst
index b004423fd4..8db1f024bd 100644
--- a/docs/source/trainer/evaluation.rst
+++ b/docs/source/trainer/evaluation.rst
@@ -17,6 +17,26 @@ specified by the the :class:`.Trainer` parameter ``eval_interval``.
     )
 
 The metrics should be provided by :meth:`.ComposerModel.metrics`.
+For more information, see the "Metrics" section in :doc:`/composer_model`.
+
+To provide a deeper intuition, here's pseudocode for the evaluation logic that occurs every ``eval_interval``:
+
+.. code:: python
+
+    metrics = model.metrics(train=False)
+
+    for batch in eval_dataloader:
+        outputs, targets = model.validate(batch)
+        metrics.update(outputs, targets)  # implements the torchmetrics interface
+
+    metrics.compute()
+
+- The trainer iterates over ``eval_dataloader`` and passes each batch to the model's :meth:`.ComposerModel.validate` method.
+- Outputs of ``model.validate`` are used to update ``metrics`` (a :class:`torchmetrics.Metric` or :class:`torchmetrics.MetricCollection` returned by :meth:`.ComposerModel.metrics`).
+- Finally, metrics over the whole validation dataset are computed.
+
+Note that the tuple returned by :meth:`.ComposerModel.validate` provides the positional arguments to ``metrics.update``.
+Please keep this in mind when using custom models and/or metrics.
 
 Multiple Datasets
 -----------------
@@ -25,7 +45,7 @@ If there are multiple validation datasets that may have different metrics,
 use :class:`.Evaluator` to specify each pair of dataloader and metrics.
 This class is just a container for a few attributes:
 
-- ``label``: a user-specified name for the metric.
+- ``label``: a user-specified name for the evaluator.
 - ``dataloader``: PyTorch :class:`~torch.utils.data.DataLoader` or our :class:`.DataSpec`. See :doc:`DataLoaders` for more details.
 - ``metric_names``: list of names of metrics to track.
 
@@ -55,3 +75,5 @@ can be specified as in the following example:
         eval_dataloader=[glue_mrpc_task, glue_mnli_task],
         ...
     )
+
+Note that ``metric_names`` must be a subset of the metrics provided by the model in :meth:`.ComposerModel.metrics`.