Skip to content

Commit c2ad6db

Browse files
hanlintA-Jacobson
authored andcommitted
Use tqdm.auto for notebooks (#298)
1 parent db39a7d commit c2ad6db

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

composer/loggers/tqdm_logger.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from dataclasses import asdict, dataclass
77
from typing import TYPE_CHECKING, Any, Dict, List, Optional
88

9-
import tqdm
109
import yaml
10+
from tqdm import auto
1111

1212
from composer.core.logging import LogLevel, TLogData, TLogDataValue, format_log_data_value
1313
from composer.core.logging.base_backend import BaseLoggerBackend
@@ -35,7 +35,7 @@ class _TQDMLoggerInstance:
3535

3636
def __init__(self, state: _TQDMLoggerInstanceState) -> None:
3737
self.state = state
38-
self.pbar = tqdm.tqdm(total=state.total,
38+
self.pbar = auto.tqdm(total=state.total,
3939
desc=state.description,
4040
position=state.position,
4141
bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
@@ -65,12 +65,12 @@ class TQDMLoggerBackend(BaseLoggerBackend):
6565
6666
Example output::
6767
68-
Epoch 1: 100%|██████████| 64/64 [00:01<00:00, 53.17it/s, loss/train=2.3023]
69-
Epoch 1 (val): 100%|██████████| 20/20 [00:00<00:00, 100.96it/s, accuracy/val=0.0995]
68+
Epoch 1: 100%|██████████| 64/64 [00:01<00:00, 53.17it/s, loss/train=2.3023]
69+
Epoch 1 (val): 100%|██████████| 20/20 [00:00<00:00, 100.96it/s, accuracy/val=0.0995]
7070
7171
.. note::
7272
73-
It is currently not possible to show additional metrics.
73+
It is currently not possible to show additional metrics.
7474
Custom metrics for the TQDM progress bar will be supported in a future version.
7575
7676
Args:

tests/test_logger.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from unittest.mock import MagicMock
66

77
import pytest
8-
import tqdm
98
from _pytest.monkeypatch import MonkeyPatch
9+
from tqdm import auto
1010

1111
from composer.core.event import Event
1212
from composer.core.logging import Logger, LogLevel
@@ -82,7 +82,8 @@ def get_mock_tqdm(position: int, *args, **kwargs):
8282
is_train_to_mock_tqdms[is_train].append(mock_tqdm)
8383
return mock_tqdm
8484

85-
monkeypatch.setattr(tqdm, "tqdm", get_mock_tqdm)
85+
monkeypatch.setattr(auto, "tqdm", get_mock_tqdm)
86+
8687
max_epochs = 2
8788
mosaic_trainer_hparams.max_duration = f"{max_epochs}ep"
8889
mosaic_trainer_hparams.loggers = [TQDMLoggerBackendHparams()]

0 commit comments

Comments
 (0)