Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions composer/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ def __init__(
self._init_kwargs = init_kwargs
self._is_in_atexit = False

# These variables are set from global rank 0 to all ranks
self.entity = None
self.project = None
# Set these variable directly to allow fetching an Artifact **without** initializing a WandB run
# When used as a LoggerDestination, these values are overriden from global rank 0 to all ranks on Event.INIT
self.entity = entity
self.project = project

def _set_is_in_atexit(self):
self._is_in_atexit = True
Expand Down Expand Up @@ -206,13 +207,14 @@ def get_file_artifact(

# replace all unsupported characters with periods
# Only alpha-numeric, periods, hyphens, and underscores are supported by wandb.
new_artifact_name = re.sub(r'[^a-zA-Z0-9-_\.]', '.', artifact_name)
if ':' not in artifact_name:
artifact_name += ':latest'

new_artifact_name = re.sub(r'[^a-zA-Z0-9-_\.:]', '.', artifact_name)
if new_artifact_name != artifact_name:
warnings.warn(('WandB permits only alpha-numeric, periods, hyphens, and underscores in artifact names. '
f"The artifact with name '{artifact_name}' will be stored as '{new_artifact_name}'."))

if ':' not in new_artifact_name:
new_artifact_name += ':latest'
try:
artifact = api.artifact('/'.join([self.entity, self.project, new_artifact_name]))
except wandb.errors.CommError as e:
Expand Down