Merged
Changes from 39 commits
Commits
44 commits
- b3a0ac2 Refactor loggers. (eracah, Aug 16, 2022)
- fcd2e2e Change: (eracah, Aug 16, 2022)
- c82e027 update progress bar logger (eracah, Aug 16, 2022)
- 6abb9c7 remove log levels from engine (eracah, Aug 16, 2022)
- 56a51ca Fix tests (eracah, Aug 16, 2022)
- 61ae68f pre-commit fixes (eracah, Aug 16, 2022)
- 21f75e4 cleaned up current_metrics test (eracah, Aug 17, 2022)
- fd03501 add fixtures (eracah, Aug 17, 2022)
- 8111c7a remove import (eracah, Aug 17, 2022)
- adb4465 isoooort! (eracah, Aug 17, 2022)
- 83d9913 update docs (eracah, Aug 17, 2022)
- d70c61a Merge branch 'dev' of https://github.com/mosaicml/composer into final… (eracah, Aug 18, 2022)
- 1f4f562 Add back in self._enabled (eracah, Aug 18, 2022)
- 32862f0 Add notimplementederror to old logging functions. (eracah, Aug 18, 2022)
- b79eae9 Add log_traces to progressbarlogger for console logging (eracah, Aug 18, 2022)
- 582c9da pre-commit fixes (eracah, Aug 18, 2022)
- 36660f8 Make a sacrifice for the pyright gods. (eracah, Aug 18, 2022)
- 0d314fe Last pyright fix (eracah, Aug 18, 2022)
- 0ea4bea fix args indent (eracah, Aug 19, 2022)
- 61eb658 Merge branch 'dev' of https://github.com/mosaicml/composer into final… (eracah, Aug 19, 2022)
- 6485238 Move pyright to be done last for pre-commit, so pyright errors (eracah, Aug 22, 2022)
- 4415bea Manually fix wandb_logger imports. (eracah, Aug 22, 2022)
- aa612fe delete blank line (eracah, Aug 22, 2022)
- aa3978f Fix wandb logging bug (eracah, Aug 22, 2022)
- f2cd6c5 fix doctest failures (eracah, Aug 22, 2022)
- c3e0e64 ci trigger (eracah, Aug 22, 2022)
- a952294 delete trigger (eracah, Aug 22, 2022)
- d4b3425 trigger ci (eracah, Aug 22, 2022)
- 701af7d Merge branch 'dev' of https://github.com/mosaicml/composer into final… (eracah, Aug 22, 2022)
- 72a50bf fix pre-commit (eracah, Aug 22, 2022)
- cf053ae add automatic step setting (eracah, Aug 22, 2022)
- f9fe90c fix doctest (eracah, Aug 22, 2022)
- 51b0d4e add step argument to log_metrics call (eracah, Aug 23, 2022)
- b6a8f0e fix test_current_metrics test (eracah, Aug 23, 2022)
- 6a78ee6 pre-commit (eracah, Aug 23, 2022)
- 17fc12f Set isort to treat wandb as a third party import for (eracah, Aug 23, 2022)
- 713135e potential metrics test fix (eracah, Aug 23, 2022)
- 3180c6f pre-commit (eracah, Aug 23, 2022)
- beeef92 reworked test_current_metrics test (eracah, Aug 23, 2022)
- 5cb7a60 Merge branch 'dev' into final-log-refac (mvpatel2000, Aug 23, 2022)
- 3c8a34e Merge branch 'dev' of https://github.com/mosaicml/composer into final… (eracah, Aug 26, 2022)
- 2bd99e7 remove step (eracah, Aug 26, 2022)
- e38adf0 pre-commit corrections. (eracah, Aug 26, 2022)
- e0ce9f5 remove loglevels (eracah, Aug 26, 2022)
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
@@ -26,16 +26,6 @@ repos:
# types: [python]
# require_serial: true
# rev: v2.12.2
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: node
types: [python]
pass_filenames: false
args: [--warnings]
additional_dependencies: ["[email protected]"]
- repo: https://github.com/PyCQA/pydocstyle
hooks:
- id: pydocstyle
@@ -103,5 +93,15 @@ repos:
metadata.kernelspec metadata.language_info.version
cell.metadata.heading_collapsed metadata.name metadata.nbconvert_exporter
metadata.version metadata.vscode
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: node
types: [python]
pass_filenames: false
args: [--warnings]
additional_dependencies: ["[email protected]"]

exclude: .ci\/release_tests\/.*
2 changes: 1 addition & 1 deletion composer/algorithms/blurpool/blurpool.py
@@ -145,7 +145,7 @@ def _log_results(self, event: Event, state: State, logger: Logger) -> None:
f'Model now has {num_blurpool_layers} BlurMaxPool2d '
f'and {num_blurconv_layers} BlurConv2D layers.')

logger.data_fit({
logger.log_hyperparameters({
'blurpool/num_blurpool_layers': num_blurpool_layers,
'blurpool/num_blurconv_layers': num_blurconv_layers,
})
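The same one-line substitution recurs in the algorithm and callback files below: `data_fit` becomes `log_hyperparameters`, while `data_epoch` and `data_batch` both become `log_metrics` (which now takes an optional `step`). A minimal sketch of a custom destination written against the new interface, assuming a post-refactor install of composer; the `PrintLogger` name and the print formatting are illustrative, not part of this PR:

```python
from typing import Any, Dict, Optional

from composer.loggers import LoggerDestination


class PrintLogger(LoggerDestination):
    """Toy destination that prints every payload it receives via the new methods."""

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Receives what used to go through data_batch / data_epoch.
        print(f'[metrics][step={step}] {metrics}')

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]) -> None:
        # Receives what used to go through data_fit.
        print(f'[hparams] {hyperparameters}')

    def log_traces(self, traces: Dict[str, Any]) -> None:
        # New hook for the algorithm traces emitted by the engine (see engine.py below).
        print(f'[traces] {traces}')


# Payloads mirroring the calls in the diffs:
dest = PrintLogger()
dest.log_hyperparameters({'blurpool/num_blurpool_layers': 10})
dest.log_metrics({'layer_freezing/layers_frozen': 4}, step=100)
```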
4 changes: 2 additions & 2 deletions composer/algorithms/factorize/factorize.py
@@ -177,12 +177,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

if self.factorize_convs:
num_factorized = module_surgery.count_module_instances(state.model, FactorizedConv2d)
logger.data_fit({
logger.log_hyperparameters({
LOG_NUM_CONV2D_REPLACEMENTS_KEY: num_factorized,
})
if self.factorize_linears:
num_factorized = module_surgery.count_module_instances(state.model, FactorizedLinear)
logger.data_fit({
logger.log_hyperparameters({
LOG_NUM_LINEAR_REPLACEMENTS_KEY: num_factorized,
})

2 changes: 1 addition & 1 deletion composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -102,7 +102,7 @@ def _log_results(self, event: Event, state: State, logger: Optional[Logger] = No
f'Model now has {num_new_modules} {module_name} modules')

if logger is not None:
logger.data_fit({
logger.log_hyperparameters({
f'{classname}/num_new_modules': num_new_modules,
})

2 changes: 1 addition & 1 deletion composer/algorithms/layer_freezing/layer_freezing.py
@@ -147,7 +147,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
freeze_start=self.freeze_start,
freeze_level=self.freeze_level,
)
logger.data_epoch({
logger.log_metrics({
'layer_freezing/layers_frozen': freeze_depth,
'layer_freezing/percentage_frozen': freeze_percentage
})
@@ -235,7 +235,7 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
state.batch_set_item(self.target_key, new_target)

if logger is not None:
logger.data_batch({
logger.log_metrics({
'progressive_resizing/height': new_input.shape[2],
'progressive_resizing/width': new_input.shape[3],
'progressive_resizing/scale_factor': scale_factor
@@ -375,7 +375,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
state.batch = set_batch_sequence_length(state.batch, curr_seq_len, self.truncate, self.preserve_end_of_sequence)

batch_size = state.batch['input_ids'].shape[0]
logger.data_batch({
logger.log_metrics({
'seq_length_warmup/curr_seq_len': curr_seq_len,
'seq_length_warmup/curr_bs': batch_size,
})
2 changes: 1 addition & 1 deletion composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -167,6 +167,6 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
f'min_channels={self.min_channels}. '
f'Model now has {layer_count} SqueezeExcite layers.')

logger.data_fit({
logger.log_hyperparameters({
'squeeze_excite/num_squeeze_excite_layers': layer_count,
})
4 changes: 2 additions & 2 deletions composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -183,7 +183,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
drop_rate=self.drop_rate,
drop_distribution=self.drop_distribution)
self.num_stochastic_layers = module_surgery.count_module_instances(state.model, target_block)
logger.data_epoch({'stochastic_depth/num_stochastic_layers': self.num_stochastic_layers})
logger.log_metrics({'stochastic_depth/num_stochastic_layers': self.num_stochastic_layers})

elif event == Event.BATCH_START and self.num_stochastic_layers:
elapsed_duration = state.get_elapsed_duration()
@@ -197,7 +197,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
module_count=self.num_stochastic_layers)
else:
current_drop_rate = self.drop_rate
logger.data_batch({'stochastic_depth/drop_rate': current_drop_rate})
logger.log_metrics({'stochastic_depth/drop_rate': current_drop_rate})


def _validate_stochastic_hparams(target_layer_name: str,
4 changes: 2 additions & 2 deletions composer/callbacks/grad_monitor.py
@@ -67,6 +67,6 @@ def after_train_batch(self, state: State, logger: Logger):
norm += param_grad_norm

norm = norm**0.5
logger.data_batch({'grad_l2_norm/step': norm})
logger.log_metrics({'grad_l2_norm/step': norm})
if self.log_layer_grad_norms:
logger.data_batch(layer_norms)
logger.log_metrics(layer_norms)
6 changes: 3 additions & 3 deletions composer/callbacks/image_visualizer.py
@@ -9,7 +9,7 @@
from PIL import Image

from composer.core import Callback, State, Time, TimeUnit
from composer.loggers import InMemoryLogger, Logger, LogLevel, WandBLogger
from composer.loggers import InMemoryLogger, Logger, WandBLogger
from composer.loss.utils import infer_target_type
from composer.utils import ensure_tuple
from composer.utils.import_helpers import MissingConditionalImportError
@@ -123,7 +123,7 @@ def _log_inputs(self, state: State, logger: Logger, data_name: str):
# Only log to the wandb logger if it is available
for destination in ensure_tuple(logger.destinations):
if isinstance(destination, WandBLogger) or isinstance(destination, InMemoryLogger):
destination.log_data(state, LogLevel.BATCH, {data_name: table})
destination.log_metrics({data_name: table})

def _log_segmented_inputs(self, state: State, logger: Logger, data_name: str):
inputs = state.batch_get_item(key=self.input_key)
@@ -144,7 +144,7 @@ def _log_segmented_inputs(self, state: State, logger: Logger, data_name: str):
# Only log to the wandb logger if it is available
for destination in ensure_tuple(logger.destinations):
if isinstance(destination, WandBLogger) or isinstance(destination, InMemoryLogger):
destination.log_data(state, LogLevel.BATCH, {data_name: table})
destination.log_metrics({data_name: table})

def before_forward(self, state: State, logger: Logger):
if self.mode.lower() == 'input' and state.timestamp.get(self.interval.unit).value % self.interval.value == 0:
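The destination-filtering pattern above can be read as a small standalone helper; this is a sketch for illustration only, assuming the `Logger` and its destinations are already constructed, with the helper name `log_table` hypothetical and `table` standing in for the `wandb.Table` built by the callback:

```python
from typing import Any

from composer.loggers import InMemoryLogger, Logger, WandBLogger
from composer.utils import ensure_tuple


def log_table(logger: Logger, data_name: str, table: Any) -> None:
    # Only WandB and in-memory destinations can store a table payload, so skip the rest.
    for destination in ensure_tuple(logger.destinations):
        if isinstance(destination, (WandBLogger, InMemoryLogger)):
            destination.log_metrics({data_name: table})
```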
2 changes: 1 addition & 1 deletion composer/callbacks/lr_monitor.py
@@ -51,4 +51,4 @@ def batch_end(self, state: State, logger: Logger):
name = optimizer.__class__.__name__
for lr in lrs:
for idx, lr in enumerate(lrs):
logger.data_batch({f'lr-{name}/group{idx}': lr})
logger.log_metrics({f'lr-{name}/group{idx}': lr})
2 changes: 1 addition & 1 deletion composer/callbacks/memory_monitor.py
@@ -94,7 +94,7 @@ def after_train_batch(self, state: State, logger: Logger):

memory_report = _get_memory_report()

logger.data_batch({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})
logger.log_metrics({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})


_MEMORY_STATS = {
4 changes: 2 additions & 2 deletions composer/callbacks/speed_monitor.py
@@ -112,11 +112,11 @@ def batch_end(self, state: State, logger: Logger):
# Log the throughput
if len(self.batch_num_samples_buffer) == self.window_size:
throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
logger.data_batch({'throughput/samples_per_sec': throughput})
logger.log_metrics({'throughput/samples_per_sec': throughput})

# Log the time
# `state.timestamp` excludes any time spent in evaluation
logger.data_batch({
logger.log_metrics({
'wall_clock/train': state.timestamp.total_wct.total_seconds(),
'wall_clock/val': self.total_eval_wct,
'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
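For a sense of the value being logged above, here is a toy version of the windowed throughput computation; the buffer contents and timings are made-up numbers, not the callback's defaults:

```python
from collections import deque

window_size = 3
batch_num_samples_buffer = deque([256, 256, 256], maxlen=window_size)  # samples per batch
batch_wct_buffer = deque([0.50, 0.48, 0.52], maxlen=window_size)       # wall-clock seconds per batch

if len(batch_num_samples_buffer) == window_size:
    throughput = sum(batch_num_samples_buffer) / sum(batch_wct_buffer)
    print({'throughput/samples_per_sec': throughput})  # 768 samples / 1.5 s = 512.0
```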
14 changes: 3 additions & 11 deletions composer/core/engine.py
@@ -75,7 +75,7 @@ def run_last(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorith
from composer.core.callback import Callback
from composer.core.event import Event
from composer.core.state import State
from composer.loggers import Logger, LogLevel
from composer.loggers import Logger
from composer.profiler import ProfilerAction

log = logging.getLogger(__name__)
@@ -334,17 +334,9 @@ def _run_algorithms(
run=True)

if self.logger is not None:
if event in (Event.INIT, Event.FIT_START):
log_level = LogLevel.FIT
if event in (Event.EPOCH_START, Event.EPOCH_END):
log_level = LogLevel.EPOCH
else:
# algs don't run on eval events, so don't have to worry about
# batch-frequency vs epoch-frequency evaluators
log_level = LogLevel.BATCH
if len(trace) > 0:
self.logger.data(log_level=log_level,
data={f'{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()})
self.logger.log_traces(
{f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for _, tr in trace.items()})

return trace

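With the `LogLevel` branching gone, the engine now sends a single trace payload per event. A standalone sketch of that payload's shape, using a stand-in `Trace` dataclass rather than composer's real trace type:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Trace:
    """Stand-in for an engine trace entry: algorithm name, event, and whether it ran."""
    name: str
    event: str
    run: bool


def build_trace_payload(trace: Dict[str, Trace]) -> Dict[str, int]:
    # Mirrors the comprehension passed to self.logger.log_traces above.
    return {f'algorithm_traces/{tr.name}/{tr.event}': 1 if tr.run else 0 for tr in trace.values()}


print(build_trace_payload({'BlurPool/Event.INIT': Trace('BlurPool', 'Event.INIT', True)}))
# {'algorithm_traces/BlurPool/Event.INIT': 1}
```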
83 changes: 31 additions & 52 deletions composer/loggers/file_logger.py
@@ -30,8 +30,6 @@ class FileLogger(LoggerDestination):  # noqa: D101
file_logger = FileLogger(
filename="{{run_name}}/logs-rank{{rank}}.txt",
buffer_size=1,
log_level=LogLevel.BATCH,
log_interval=2,
flush_interval=50
)
trainer = Trainer(
@@ -99,20 +97,9 @@ class FileLogger(LoggerDestination):  # noqa: D101
capture_stderr (bool, optional): Whether to include the ``stderr``in ``filename``. (default: ``True``)
buffer_size (int, optional): Buffer size. See :py:func:`open`.
Default: ``1`` for line buffering.
log_level (LogLevel, optional):
:class:`~.logger.LogLevel` (i.e. unit of resolution) at
which to record. Default: :attr:`~.LogLevel.EPOCH`.
log_interval (int, optional):
Frequency to print logs. If ``log_level`` is :attr:`~.LogLevel.EPOCH`,
logs will only be recorded every n epochs. If ``log_level`` is
:attr:`~.LogLevel.BATCH`, logs will be printed every n batches. Otherwise, if
``log_level`` is :attr:`~.LogLevel.FIT`, this parameter is ignored, as calls
at the :attr:`~.LogLevel.FIT` log level are always recorded. Default: ``1``.
flush_interval (int, optional): How frequently to flush the log to the file,
relative to the ``log_level``. For example, if the ``log_level`` is
:attr:`~.LogLevel.EPOCH`, then the logfile will be flushed every n epochs. If
the ``log_level`` is :attr:`~.LogLevel.BATCH`, then the logfile will be
flushed every n batches. Default: ``100``.
log_traces (bool, optional): Whether to log algorithm traces. See :class:`~.Engine` for more detail.
flush_interval (int, optional): How frequently to flush the log to the file in batches
Default: ``100``.
overwrite (bool, optional): Whether to overwrite an existing logfile. (default: ``False``)
"""

@@ -124,8 +111,7 @@ def __init__(
capture_stdout: bool = True,
capture_stderr: bool = True,
buffer_size: int = 1,
log_level: LogLevel = LogLevel.EPOCH,
log_interval: int = 1,
log_traces: bool = True,
flush_interval: int = 100,
overwrite: bool = False,
) -> None:
@@ -134,8 +120,7 @@ def __init__(
artifact_name = filename.replace(os.path.sep, '/')
self.artifact_name_format = artifact_name
self.buffer_size = buffer_size
self.log_level = log_level
self.log_interval = log_interval
self.should_log_traces = log_traces
self.flush_interval = flush_interval
self.is_batch_interval = False
self.is_epoch_interval = False
@@ -184,39 +169,34 @@ def artifact_name(self) -> str:

return name

def batch_start(self, state: State, logger: Logger) -> None:
self.is_batch_interval = (int(state.timestamp.batch) + 1) % self.log_interval == 0

def epoch_start(self, state: State, logger: Logger) -> None:
self.is_epoch_interval = (int(state.timestamp.epoch) + 1) % self.log_interval == 0
# Flush any log calls that occurred during INIT or FIT_START
self._flush_file(logger)

def _will_log(self, log_level: LogLevel) -> bool:
if log_level == LogLevel.FIT:
return True # fit is always logged
if log_level == LogLevel.EPOCH:
if self.log_level < LogLevel.EPOCH:
return False
if self.log_level > LogLevel.EPOCH:
return True
return self.is_epoch_interval
if log_level == LogLevel.BATCH:
if self.log_level < LogLevel.BATCH:
return False
if self.log_level > LogLevel.BATCH:
return True
return self.is_batch_interval
raise ValueError(f'Unknown log level: {log_level}')

def log_data(self, state: State, log_level: LogLevel, data: Dict[str, Any]):
if not self._will_log(log_level):
return
data_str = format_log_data_value(data)
self.write(
f'[{log_level.name}][batch={int(state.timestamp.batch)}]: ',
data_str + '\n',
)
def log_traces(self, traces: Dict[str, Any]):
if self.should_log_traces:
for trace_name, trace in traces.items():
trace_str = format_log_data_value(trace)
self.write(
f'[trace]: {trace_name}:',
trace_str + '\n',
)

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for metric_name, metric in metrics.items():
metric_str = format_log_data_value(metric)
self.write(
f'[metric][batch={step}]: ',
f'{metric_name}: {metric_str} \n',
)

def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
for hparam_name, hparam_value in hyperparameters.items():
hparam_str = format_log_data_value(hparam_value)
self.write(
f'[hyperparameter]: ',
f'{hparam_name}: {hparam_str} \n',
)

def init(self, state: State, logger: Logger) -> None:
del logger # unused
@@ -233,16 +213,15 @@ def init(self, state: State, logger: Logger) -> None:

def batch_end(self, state: State, logger: Logger) -> None:
assert self.file is not None
if self.log_level == LogLevel.BATCH and int(state.timestamp.batch) % self.flush_interval == 0:
if int(state.timestamp.batch) % self.flush_interval == 0:
self._flush_file(logger)

def eval_start(self, state: State, logger: Logger) -> None:
# Flush any log calls that occurred during INIT when using the trainer in eval-only mode
self._flush_file(logger)

def epoch_end(self, state: State, logger: Logger) -> None:
if self.log_level > LogLevel.EPOCH or self.log_level == LogLevel.EPOCH and int(
state.timestamp.epoch) % self.flush_interval == 0:
if int(state.timestamp.epoch) % self.flush_interval == 0:
self._flush_file(logger)

def write(self, prefix: str, s: str):
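Given the prefixes in the methods above, the emitted log lines should look roughly like the following; this is a sketch inferred from the `write()` calls shown here, so exact value formatting (via `format_log_data_value`) and spacing may differ:

```python
# Approximate shape of the lines the refactored FileLogger writes:
expected_lines = [
    '[trace]: algorithm_traces/BlurPool/Event.INIT: 1',
    '[metric][batch=50]: throughput/samples_per_sec: 512.0',
    '[hyperparameter]: blurpool/num_blurpool_layers: 10',
]
for line in expected_lines:
    print(line)
```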