Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from composer.utils.checkpoint import load_checkpoint, save_checkpoint
from composer.utils.file_helpers import get_file
from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.tqdm_utils import monkeypatch_tqdm

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -741,6 +742,7 @@ def __init__(
# Profiling
profiler: Optional[Profiler] = None,
):
monkeypatch_tqdm()
algorithms = list(ensure_tuple(algorithms))

# Determine whether DeepSpeed is enabled
Expand Down
2 changes: 2 additions & 0 deletions composer/trainer/trainer_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from composer.trainer.trainer import Trainer
from composer.utils import dist, reproducibility
from composer.utils.object_store.object_store_hparams import ObjectStoreHparams, object_store_registry
from composer.utils.tqdm_utils import monkeypatch_tqdm

if TYPE_CHECKING:
from typing import TypedDict
Expand Down Expand Up @@ -406,6 +407,7 @@ def validate(self):

def initialize_object(self) -> Trainer:
self.validate()
monkeypatch_tqdm()

# Set the Python LogLevel for Composer
import composer
Expand Down
2 changes: 2 additions & 0 deletions composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from composer.utils.object_store import LibcloudObjectStore, ObjectStore, ObjectStoreTransientError
from composer.utils.retrying import retry
from composer.utils.string_enum import StringEnum
from composer.utils.tqdm_utils import monkeypatch_tqdm

__all__ = [
'ensure_tuple',
Expand Down Expand Up @@ -40,4 +41,5 @@
'enable_env_report',
'print_env',
'retry',
'monkeypatch_tqdm',
]
57 changes: 57 additions & 0 deletions composer/utils/tqdm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Helpers to fix :mod:`tqdm` progress bars when streamed over the network."""

# Adapted from https://github.com/tqdm/tqdm/issues/1319#issuecomment-1100951505

import os
import sys

import tqdm.std
import tqdm.utils

__all__ = ['monkeypatch_tqdm']

# Module-level aliases for the tqdm helpers that the replacement status printer in this module calls.
# ``_disp_len`` is applied to each status string to decide how much padding a refreshed line needs.
_disp_len = tqdm.utils.disp_len
# ``_unicode`` is applied to every string before it is written to the output stream.
# NOTE(review): this is a private tqdm name (leading underscore) -- confirm it exists in every tqdm
# version this package supports, since a missing attribute would raise at import time.
_unicode = tqdm.utils._unicode


def _should_printer_print_new_line():
    """Return whether the status printer should emit a leading newline.

    The answer is ``True`` when either of the following holds:

    * the ``KUBERNETES_SERVICE_HOST`` environment variable is set (to any value, even empty), or
    * the ``TQDM_PRINTER_NEW_LINE`` environment variable equals ``1`` or ``true`` (case-insensitive).

    Returns:
        bool: ``True`` if a newline should be printed, ``False`` otherwise.
    """
    # Presence alone is enough here; the variable's value is never inspected.
    if 'KUBERNETES_SERVICE_HOST' in os.environ:
        return True

    opt_in = os.environ.get('TQDM_PRINTER_NEW_LINE', '')
    return opt_in.upper() in ('1', 'TRUE')


def _new_status_printer(file):
    """Build a callable that prints a status line to ``file`` and overwrites it in place.

    When this function is called, both standard streams are flushed if ``file`` is
    ``sys.stdout`` or ``sys.stderr``. A single ``'\\n'`` is then written to ``file`` if
    :func:`_should_printer_print_new_line` returns ``True``.

    Note that if the string is longer than a line, then in-place updating may not work
    (it will print a new line at each refresh).

    Args:
        file: The output stream. It must provide ``write``; ``flush`` is used when available.

    Returns:
        A function taking one string ``s``. It writes a carriage return followed by ``s``,
        padded with spaces to cover whatever was left over from the previous, longer status.
    """
    stream = file
    flush_stream = getattr(stream, 'flush', lambda: None)  # pragma: no cover

    if stream in (sys.stderr, sys.stdout):
        # Flush stderr first, then stdout, tolerating streams without a ``flush`` method.
        for std_stream in (sys.stderr, sys.stdout):
            getattr(std_stream, 'flush', lambda: None)()

    def emit(text):
        stream.write(_unicode(text))
        flush_stream()

    if _should_printer_print_new_line():
        # Written once, when the printer is created; this write is not followed by a flush.
        write_newline = getattr(stream, 'write', lambda x: None)
        write_newline('\n')

    # Display width of the most recently printed status; used to blank out stale characters.
    previous_width = 0

    def print_status(s):
        nonlocal previous_width
        width = _disp_len(s)
        padding = ' ' * max(previous_width - width, 0)
        emit('\r' + s + padding)
        previous_width = width

    return print_status


def monkeypatch_tqdm():
    """Swap :meth:`tqdm.std.tqdm.status_printer` for a printer that works when streamed over a network.

    The replacement is :func:`_new_status_printer`, installed as a static method on the
    :class:`tqdm.std.tqdm` class. Calling this function more than once simply reinstalls it.
    """
    patched_printer = staticmethod(_new_status_printer)
    setattr(tqdm.std.tqdm, 'status_printer', patched_printer)