Skip to content

Logging a TorchMetric resets it, producing very subtle and unexpected behaviour #11262

@nils-werner

Description

@nils-werner

🐛 Bug

As I understand it, the recommended practice is to log a TorchMetric instance instead of its .compute() output (only then is DDP synchronization possible, should you want it):

def __init__(self):
    # Keep the metric as a module attribute so the object itself can be logged.
    self.val_acc = Accuracy()
    ...

def validation_step_end(self, outputs):
    # Log the Metric object rather than the result of self.val_acc.compute().
    self.log("val/acc", self.val_acc)

However, I was surprised to learn that this also seems to quietly reset the metric, even when the metric should not be reset:

def __init__(self):
    self.val_acc = Accuracy()
    # Intended to hold the best (highest) validation accuracy seen so far.
    self.val_acc_best = MaxMetric()
    ...

def validation_epoch_end(self, outputs):
    # Feed this epoch's accuracy into the running best.
    self.val_acc_best.update(self.val_acc.compute())
    # Logging the Metric object is the call reported to reset its state.
    self.log("val/acc_best", self.val_acc_best)
    # after this line, self.val_acc_best.value = -Inf

Imagine my surprise that my "best accuracy" was decreasing every few epochs.

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy, MaxMetric

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Fixed dataset of random feature vectors.

    ``length`` rows of ``size`` values drawn with ``torch.randn`` are sampled
    once at construction time; indexing returns the stored row.
    """

    def __init__(self, size, length):
        # Sample every row up front so repeated indexing returns the same data.
        self.data = torch.randn((length, size))
        self.len = length

    def __getitem__(self, index):
        """Return the pre-sampled row at ``index``."""
        row = self.data[index]
        return row

    def __len__(self):
        """Return the number of rows given as ``length`` at construction."""
        return self.len


class BoringModel(LightningModule):
    """Minimal LightningModule reproducing the metric-reset issue.

    ``validation_step_end`` feeds each validation output into a ``MaxMetric``
    and then logs the Metric object with ``self.log``. Prints around those
    calls show the metric's ``value`` state. Per the issue, ``value`` is back
    at -inf in the second epoch.
    """

    def __init__(self):
        super().__init__()
        # Single linear layer mapping 32 input features to 2 outputs.
        self.layer = torch.nn.Linear(32, 2)
        # Not referenced anywhere else in this reproduction script.
        self.val_acc = Accuracy()
        # The metric whose state is printed in ``validation_step_end``.
        self.val_acc_best = MaxMetric()

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Use the sum of the model outputs as a dummy loss and log it."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the same dummy summed-output loss on validation data.

        The returned ``"loss"`` entry is read as ``outputs['loss']`` in
        ``validation_step_end``.
        """
        val_loss = self(batch).sum()
        self.log("valid_loss", val_loss)
        return {"loss": val_loss}

    def validation_step_end(self, outputs):
        """Update ``val_acc_best``, log it, and print its state around each call.

        The three prints show ``val_acc_best.value`` before ``update``, after
        ``update``, and after ``self.log``. This is how the issue observes the
        value being reset between epochs.
        """
        print(self.val_acc_best.value)
        # Stands in for "this epoch's val accuracy", but it is actually the
        # summed-output loss returned by ``validation_step``.
        # NOTE(review): MaxMetric keeps the largest value seen, so applied to a
        # loss it tracks the worst rather than the best value; this does not
        # affect the reset being demonstrated.
        self.val_acc_best.update(outputs['loss'])
        print(self.val_acc_best.value)
        # Logging the Metric object (not its .compute() result) is the call
        # under investigation.
        self.log("valid_loss_best", self.val_acc_best)
        print(self.val_acc_best.value)

    def test_step(self, batch, batch_idx):
        """Log the dummy summed-output loss on test data (returns nothing)."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Optimize only the linear layer with plain SGD at lr=0.1."""
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    """Fit ``BoringModel`` for two short epochs, then test it.

    Three separate ``RandomDataset(32, 64)`` loaders with batch size 2 are
    built for training, validation and testing. Each stage is limited to one
    batch and sanity validation is disabled. This keeps the run short, so the
    metric state printed in epoch 1 and epoch 2 can be compared.
    """
    def _loader():
        # Fresh random data per stage; created in train, val, test order.
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    train_data = _loader()
    val_data = _loader()
    test_data = _loader()

    model = BoringModel()
    trainer_options = dict(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    # Only run the reproduction when executed as a script, not on import.
    run()

As the output shows, self.val_acc_best.value is back to -Inf in the second epoch, so the metric fails to track the best loss.

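Expected behavior

Logging a metric object should not reset state that is meant to persist across epochs: self.val_acc_best.value should keep the best value seen so far instead of falling back to -Inf.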

Environment

* CUDA:
        - GPU:
                - NVIDIA GeForce GTX 1060 6GB
        - available:         True
        - version:           10.2
* Packages:
        - numpy:             1.20.3
        - pyTorch_debug:     False
        - pyTorch_version:   1.10.0+cu102
        - pytorch-lightning: 1.6.0dev
        - tqdm:              4.62.3
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         
        - python:            3.8.12
        - version:           #1 ZEN SMP PREEMPT Wed, 22 Dec 2021 09:23:53 +0000

cc @Borda @carmocca @edward-io @ananthsub @rohitgr7

Metadata

Metadata

Assignees

No one assigned

    Labels

    feature (Is an improvement or enhancement), logging (Related to the `LoggerConnector` and `log()`)

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions