15
15
from composer .core .time import Timestamp , TimeUnit
16
16
from composer .loggers .logger import Logger , LogLevel , format_log_data_value
17
17
from composer .loggers .logger_destination import LoggerDestination
18
- from composer .utils import dist
18
+ from composer .utils import dist , is_notebook
19
19
20
20
__all__ = ['ProgressBarLogger' ]
21
21
@@ -27,7 +27,7 @@ class _ProgressBar:
27
27
def __init__ (
28
28
self ,
29
29
total : Optional [int ],
30
- position : int ,
30
+ position : Optional [ int ] ,
31
31
bar_format : str ,
32
32
file : TextIO ,
33
33
metrics : Dict [str , Any ],
@@ -40,7 +40,7 @@ def __init__(
40
40
self .position = position
41
41
self .timestamp_key = timestamp_key
42
42
self .file = file
43
- is_atty = os .isatty (self .file .fileno ())
43
+ is_atty = is_notebook () or os .isatty (self .file .fileno ())
44
44
self .pbar = tqdm .auto .tqdm (
45
45
total = total ,
46
46
position = position ,
@@ -51,7 +51,8 @@ def __init__(
51
51
# We set `leave=False` so TQDM does not jump around, but we emulate `leave=True` behavior when closing
52
52
# by printing a dummy newline and refreshing to force tqdm to print to a stale line
53
53
# But on k8s, we need `leave=True`, as it would otherwise overwrite the pbar in place
54
- leave = not is_atty ,
54
+ # If in a notebook, then always set leave=True, as otherwise jupyter would remote the progress bars
55
+ leave = True if is_notebook () else not is_atty ,
55
56
postfix = metrics ,
56
57
unit = unit ,
57
58
)
@@ -67,18 +68,22 @@ def update(self, n=1):
67
68
def update_to_timestamp (self , timestamp : Timestamp ):
68
69
n = int (getattr (timestamp , self .timestamp_key ))
69
70
n = n - self .pbar .n
70
- self .pbar . update (int (n ))
71
+ self .update (int (n ))
71
72
72
73
def close (self ):
73
- if self . position != 0 :
74
- # Force a (potentially hidden) progress bar to re-render itself
75
- # Don't render the dummy pbar (at position 0), since that will clear a real pbar (at position 1)
74
+ if is_notebook () :
75
+ # If in a notebook, always refresh before closing, so the
76
+ # finished progress is displayed
76
77
self .pbar .refresh ()
77
- # Create a newline that will not be erased by leave=False. This allows for the finished pbar to be cached in the terminal
78
- # This emulates `leave=True` without progress bar jumping
79
- print ('' , file = self .file , flush = True )
80
-
81
- self .pbar .close ()
78
+ else :
79
+ if self .position != 0 :
80
+ # Force a (potentially hidden) progress bar to re-render itself
81
+ # Don't render the dummy pbar (at position 0), since that will clear a real pbar (at position 1)
82
+ self .pbar .refresh ()
83
+ # Create a newline that will not be erased by leave=False. This allows for the finished pbar to be cached in the terminal
84
+ # This emulates `leave=True` without progress bar jumping
85
+ print ('' , file = self .file , flush = True )
86
+ self .pbar .close ()
82
87
83
88
def state_dict (self ) -> Dict [str , Any ]:
84
89
pbar_state = self .pbar .format_dict
@@ -226,7 +231,8 @@ def _build_pbar(self, state: State, is_train: bool) -> _ProgressBar:
226
231
with the time (in units of ``max_duration.unit``) at which evaluation runs.
227
232
"""
228
233
# Always using position=1 to avoid jumping progress bars
229
- position = 1
234
+ # In jupyter notebooks, no need for the dummy pbar, so use the default position
235
+ position = None if is_notebook () else 1
230
236
desc = f'{ state .dataloader_label :15} '
231
237
max_duration_unit = None if state .max_duration is None else state .max_duration .unit
232
238
@@ -267,23 +273,27 @@ def _build_pbar(self, state: State, is_train: bool) -> _ProgressBar:
267
273
total = total ,
268
274
position = position ,
269
275
keys_to_log = _IS_TRAIN_TO_KEYS_TO_LOG [is_train ],
270
- bar_format = desc + ' {l_bar}{bar:25}{r_bar}{bar:-1b}' ,
276
+ # In a notebook, the `bar_format` should not include the {bar}, as otherwise
277
+ # it would appear twice.
278
+ bar_format = desc + ' {l_bar}' + ('' if is_notebook () else '{bar:25}' ) + '{r_bar}{bar:-1b}' ,
271
279
unit = unit .value .lower (),
272
280
metrics = {},
273
281
timestamp_key = timestamp_key ,
274
282
)
275
283
276
284
def init (self , state : State , logger : Logger ) -> None :
277
285
del state , logger # unused
278
- self .dummy_pbar = _ProgressBar (
279
- file = self .stream ,
280
- position = 0 ,
281
- total = 1 ,
282
- metrics = {},
283
- keys_to_log = [],
284
- bar_format = '{bar:-1b}' ,
285
- timestamp_key = '' ,
286
- )
286
+ if not is_notebook ():
287
+ # Notebooks don't need the dummy progress bar; otherwise, it would be visible.
288
+ self .dummy_pbar = _ProgressBar (
289
+ file = self .stream ,
290
+ position = 0 ,
291
+ total = 1 ,
292
+ metrics = {},
293
+ keys_to_log = [],
294
+ bar_format = '{bar:-1b}' ,
295
+ timestamp_key = '' ,
296
+ )
287
297
288
298
def epoch_start (self , state : State , logger : Logger ) -> None :
289
299
if self .show_pbar and not self .train_pbar :
0 commit comments