Skip to content

Commit bc58b79

Browse files
Fix Notebook Progress Bars (#1313)
#1264 broke the progress bars in notebooks: it garbled the formatting and caused an io.UnsupportedOperation error in Colab when calling sys.stderr.fileno(). This PR fixes both issues. Closes #1312. Closes https://mosaicml.atlassian.net/browse/CO-770
1 parent 6a69007 commit bc58b79

File tree

4 files changed

+50
-31
lines changed

4 files changed

+50
-31
lines changed

composer/loggers/progress_bar_logger.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from composer.core.time import Timestamp, TimeUnit
1616
from composer.loggers.logger import Logger, LogLevel, format_log_data_value
1717
from composer.loggers.logger_destination import LoggerDestination
18-
from composer.utils import dist
18+
from composer.utils import dist, is_notebook
1919

2020
__all__ = ['ProgressBarLogger']
2121

@@ -27,7 +27,7 @@ class _ProgressBar:
2727
def __init__(
2828
self,
2929
total: Optional[int],
30-
position: int,
30+
position: Optional[int],
3131
bar_format: str,
3232
file: TextIO,
3333
metrics: Dict[str, Any],
@@ -40,7 +40,7 @@ def __init__(
4040
self.position = position
4141
self.timestamp_key = timestamp_key
4242
self.file = file
43-
is_atty = os.isatty(self.file.fileno())
43+
is_atty = is_notebook() or os.isatty(self.file.fileno())
4444
self.pbar = tqdm.auto.tqdm(
4545
total=total,
4646
position=position,
@@ -51,7 +51,8 @@ def __init__(
5151
# We set `leave=False` so TQDM does not jump around, but we emulate `leave=True` behavior when closing
5252
# by printing a dummy newline and refreshing to force tqdm to print to a stale line
5353
# But on k8s, we need `leave=True`, as it would otherwise overwrite the pbar in place
54-
leave=not is_atty,
54+
# If in a notebook, then always set leave=True, as otherwise jupyter would remove the progress bars
55+
leave=True if is_notebook() else not is_atty,
5556
postfix=metrics,
5657
unit=unit,
5758
)
@@ -67,18 +68,22 @@ def update(self, n=1):
6768
def update_to_timestamp(self, timestamp: Timestamp):
6869
n = int(getattr(timestamp, self.timestamp_key))
6970
n = n - self.pbar.n
70-
self.pbar.update(int(n))
71+
self.update(int(n))
7172

7273
def close(self):
73-
if self.position != 0:
74-
# Force a (potentially hidden) progress bar to re-render itself
75-
# Don't render the dummy pbar (at position 0), since that will clear a real pbar (at position 1)
74+
if is_notebook():
75+
# If in a notebook, always refresh before closing, so the
76+
# finished progress is displayed
7677
self.pbar.refresh()
77-
# Create a newline that will not be erased by leave=False. This allows for the finished pbar to be cached in the terminal
78-
# This emulates `leave=True` without progress bar jumping
79-
print('', file=self.file, flush=True)
80-
81-
self.pbar.close()
78+
else:
79+
if self.position != 0:
80+
# Force a (potentially hidden) progress bar to re-render itself
81+
# Don't render the dummy pbar (at position 0), since that will clear a real pbar (at position 1)
82+
self.pbar.refresh()
83+
# Create a newline that will not be erased by leave=False. This allows for the finished pbar to be cached in the terminal
84+
# This emulates `leave=True` without progress bar jumping
85+
print('', file=self.file, flush=True)
86+
self.pbar.close()
8287

8388
def state_dict(self) -> Dict[str, Any]:
8489
pbar_state = self.pbar.format_dict
@@ -226,7 +231,8 @@ def _build_pbar(self, state: State, is_train: bool) -> _ProgressBar:
226231
with the time (in units of ``max_duration.unit``) at which evaluation runs.
227232
"""
228233
# Always using position=1 to avoid jumping progress bars
229-
position = 1
234+
# In jupyter notebooks, no need for the dummy pbar, so use the default position
235+
position = None if is_notebook() else 1
230236
desc = f'{state.dataloader_label:15}'
231237
max_duration_unit = None if state.max_duration is None else state.max_duration.unit
232238

@@ -267,23 +273,27 @@ def _build_pbar(self, state: State, is_train: bool) -> _ProgressBar:
267273
total=total,
268274
position=position,
269275
keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[is_train],
270-
bar_format=desc + ' {l_bar}{bar:25}{r_bar}{bar:-1b}',
276+
# In a notebook, the `bar_format` should not include the {bar}, as otherwise
277+
# it would appear twice.
278+
bar_format=desc + ' {l_bar}' + ('' if is_notebook() else '{bar:25}') + '{r_bar}{bar:-1b}',
271279
unit=unit.value.lower(),
272280
metrics={},
273281
timestamp_key=timestamp_key,
274282
)
275283

276284
def init(self, state: State, logger: Logger) -> None:
277285
del state, logger # unused
278-
self.dummy_pbar = _ProgressBar(
279-
file=self.stream,
280-
position=0,
281-
total=1,
282-
metrics={},
283-
keys_to_log=[],
284-
bar_format='{bar:-1b}',
285-
timestamp_key='',
286-
)
286+
if not is_notebook():
287+
# Notebooks don't need the dummy progress bar; otherwise, it would be visible.
288+
self.dummy_pbar = _ProgressBar(
289+
file=self.stream,
290+
position=0,
291+
total=1,
292+
metrics={},
293+
keys_to_log=[],
294+
bar_format='{bar:-1b}',
295+
timestamp_key='',
296+
)
287297

288298
def epoch_start(self, state: State, logger: Logger) -> None:
289299
if self.show_pbar and not self.train_pbar:

composer/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from composer.utils.import_helpers import MissingConditionalImportError, import_object
1212
from composer.utils.inference import export_for_inference
1313
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
14-
from composer.utils.misc import is_model_deepspeed
14+
from composer.utils.misc import is_model_deepspeed, is_notebook
1515
from composer.utils.object_store import (LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, S3ObjectStore,
1616
SFTPObjectStore)
1717
from composer.utils.retrying import retry
@@ -31,6 +31,7 @@
3131
'MissingConditionalImportError',
3232
'import_object',
3333
'is_model_deepspeed',
34+
'is_notebook',
3435
'StringEnum',
3536
'load_checkpoint',
3637
'save_checkpoint',

composer/utils/collect_env.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import importlib_metadata
5050
import psutil
5151

52+
from composer.utils.misc import is_notebook
53+
5254
__all__ = ['configure_excepthook', 'disable_env_report', 'enable_env_report', 'print_env']
5355

5456
# Check if PyTorch is installed
@@ -70,14 +72,11 @@
7072
COMPOSER_AVAILABLE = False
7173

7274
# Check if we're running in a notebook
73-
try:
74-
__IPYTHON__ #type: ignore
75+
IPYTHON_AVAILABLE = is_notebook()
76+
if IPYTHON_AVAILABLE:
7577
from composer.utils.import_helpers import import_object
7678
get_ipython = import_object('IPython:get_ipython')
7779
nb = get_ipython()
78-
IPYTHON_AVAILABLE = True
79-
except (NameError,):
80-
IPYTHON_AVAILABLE = False
8180

8281
# Place to keep track of the original excepthook
8382
_orig_excepthook = None

composer/utils/misc.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
__all__ = ['is_model_deepspeed']
8+
__all__ = ['is_model_deepspeed', 'is_notebook']
99

1010

1111
def is_model_deepspeed(model: torch.nn.Module) -> bool:
@@ -16,3 +16,12 @@ def is_model_deepspeed(model: torch.nn.Module) -> bool:
1616
return False
1717
else:
1818
return isinstance(model, deepspeed.DeepSpeedEngine)
19+
20+
21+
def is_notebook() -> bool:
    """Report whether Composer is running inside an IPython/Jupyter session.

    Detection is based solely on whether the name ``__IPYTHON__`` can be
    resolved from this module.
    NOTE(review): any IPython front end that defines ``__IPYTHON__`` -- presumably
    including a plain terminal IPython shell, not only notebooks -- will be
    reported as a notebook; confirm this is acceptable for callers.

    Returns:
        bool: ``True`` if ``__IPYTHON__`` is defined, ``False`` otherwise.
    """
    try:
        # Merely referencing the name is the probe: it raises NameError when undefined.
        __IPYTHON__  # type: ignore
    except NameError:
        return False
    else:
        return True

0 commit comments

Comments
 (0)