Skip to content

Commit 6597dae

Browse files
authored
Integrate SwanLab for offline/online experiment tracking for Accelerate (#3605)
* add support for SwanLabTracker and update related documentation * add emoji in FRAMWORK * apply the style corrections and quality control * add support for SwanLabTracker in tests * fix bug in test_tracking
1 parent 8878d93 commit 6597dae

File tree

12 files changed

+272
-5
lines changed

12 files changed

+272
-5
lines changed

docs/source/package_reference/tracking.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,8 @@ rendered properly in your Markdown viewer.
4848

4949
[[autodoc]] tracking.ClearMLTracker
5050
- __init__
51+
52+
## SwanLabTracker
53+
54+
[[autodoc]] tracking.SwanLabTracker
55+
- __init__

examples/by_feature/deepspeed_with_config_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def parse_args():
218218
default="all",
219219
help=(
220220
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
221-
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
221+
' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
222222
"Only applicable when `--with_tracking` is passed."
223223
),
224224
)

examples/by_feature/megatron_lm_gpt_pretraining.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def parse_args():
215215
default="all",
216216
help=(
217217
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
218-
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
218+
' `"wandb"`, `"comet_ml"`, and `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
219219
"Only applicable when `--with_tracking` is passed."
220220
),
221221
)

setup.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@
4141
extras["rich"] = ["rich"]
4242

4343
extras["test_fp8"] = ["torchao"] # note: TE for now needs to be done via pulling down the docker image directly
44-
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
44+
extras["test_trackers"] = [
45+
"wandb",
46+
"comet-ml",
47+
"tensorboard",
48+
"dvclive",
49+
"mlflow",
50+
"matplotlib",
51+
"swanlab",
52+
]
4553
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
4654

4755
extras["sagemaker"] = [

src/accelerate/accelerator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ class Accelerator:
230230
- `"tensorboard"`
231231
- `"wandb"`
232232
- `"comet_ml"`
233+
- `"swanlab"`
233234
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
234235
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
235236
project_config ([`~utils.ProjectConfiguration`], *optional*):

src/accelerate/test_utils/testing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
is_pytest_available,
6262
is_schedulefree_available,
6363
is_sdaa_available,
64+
is_swanlab_available,
6465
is_tensorboard_available,
6566
is_timm_available,
6667
is_torch_version,
@@ -482,6 +483,13 @@ def require_dvclive(test_case):
482483
return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)
483484

484485

486+
def require_swanlab(test_case):
487+
"""
488+
Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
489+
"""
490+
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
491+
492+
485493
def require_pandas(test_case):
486494
"""
487495
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
@@ -536,7 +544,7 @@ def require_matplotlib(test_case):
536544

537545

538546
_atleast_one_tracker_available = (
539-
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
547+
any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
540548
)
541549

542550

src/accelerate/tracking.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
is_comet_ml_available,
3535
is_dvclive_available,
3636
is_mlflow_available,
37+
is_swanlab_available,
3738
is_tensorboard_available,
3839
is_wandb_available,
3940
listify,
@@ -63,6 +64,9 @@
6364
if is_dvclive_available():
6465
_available_trackers.append(LoggerType.DVCLIVE)
6566

67+
if is_swanlab_available():
68+
_available_trackers.append(LoggerType.SWANLAB)
69+
6670
logger = get_logger(__name__)
6771

6872

@@ -1061,6 +1065,106 @@ def finish(self):
10611065
self.live.end()
10621066

10631067

1068+
class SwanLabTracker(GeneralTracker):
1069+
"""
1070+
A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.
1071+
1072+
Args:
1073+
run_name (`str`):
1074+
The name of the experiment run.
1075+
**kwargs (additional keyword arguments, *optional*):
1076+
Additional key word arguments passed along to the `swanlab.init` method.
1077+
"""
1078+
1079+
name = "swanlab"
1080+
requires_logging_directory = False
1081+
main_process_only = False
1082+
1083+
def __init__(self, run_name: str, **kwargs):
1084+
super().__init__()
1085+
self.run_name = run_name
1086+
self.init_kwargs = kwargs
1087+
1088+
@on_main_process
1089+
def start(self):
1090+
import swanlab
1091+
1092+
self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
1093+
swanlab.config["FRAMEWORK"] = "🤗Accelerate" # add accelerate logo in config
1094+
logger.debug(f"Initialized SwanLab project {self.run_name}")
1095+
logger.debug(
1096+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
1097+
)
1098+
1099+
@property
1100+
def tracker(self):
1101+
return self.run
1102+
1103+
@on_main_process
1104+
def store_init_configuration(self, values: dict):
1105+
"""
1106+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
1107+
1108+
Args:
1109+
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
1110+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
1111+
`str`, `float`, `int`, or `None`.
1112+
"""
1113+
import swanlab
1114+
1115+
swanlab.config.update(values, allow_val_change=True)
1116+
logger.debug("Stored initial configuration hyperparameters to SwanLab")
1117+
1118+
@on_main_process
1119+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
1120+
"""
1121+
Logs `values` to the current run.
1122+
1123+
Args:
1124+
data : Dict[str, DataType]
1125+
Data must be a dict. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
1126+
`float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
1127+
step : int, optional
1128+
The step number of the current data, if not provided, it will be automatically incremented.
1129+
If step is duplicated, the data will be ignored.
1130+
kwargs:
1131+
Additional key word arguments passed along to the `swanlab.log` method. Likes:
1132+
print_to_console : bool, optional
1133+
Whether to print the data to the console, the default is False.
1134+
"""
1135+
self.run.log(values, step=step, **kwargs)
1136+
logger.debug("Successfully logged to SwanLab")
1137+
1138+
@on_main_process
1139+
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
1140+
"""
1141+
Logs `images` to the current run.
1142+
1143+
Args:
1144+
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
1145+
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
1146+
step (`int`, *optional*):
1147+
The run step. If included, the log will be affiliated with this step.
1148+
kwargs:
1149+
Additional key word arguments passed along to the `swanlab.log` method. Likes:
1150+
print_to_console : bool, optional
1151+
Whether to print the data to the console, the default is False.
1152+
"""
1153+
import swanlab
1154+
1155+
for k, v in values.items():
1156+
self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
1157+
logger.debug("Successfully logged images to SwanLab")
1158+
1159+
@on_main_process
1160+
def finish(self):
1161+
"""
1162+
Closes `swanlab` writer
1163+
"""
1164+
self.run.finish()
1165+
logger.debug("SwanLab run closed")
1166+
1167+
10641168
LOGGER_TYPE_TO_CLASS = {
10651169
"aim": AimTracker,
10661170
"comet_ml": CometMLTracker,
@@ -1069,6 +1173,7 @@ def finish(self):
10691173
"wandb": WandBTracker,
10701174
"clearml": ClearMLTracker,
10711175
"dvclive": DVCLiveTracker,
1176+
"swanlab": SwanLabTracker,
10721177
}
10731178

10741179

@@ -1093,6 +1198,7 @@ def filter_trackers(
10931198
- `"comet_ml"`
10941199
- `"mlflow"`
10951200
- `"dvclive"`
1201+
- `"swanlab"`
10961202
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
10971203
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
10981204
logging_dir (`str`, `os.PathLike`, *optional*):

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
is_sagemaker_available,
122122
is_schedulefree_available,
123123
is_sdaa_available,
124+
is_swanlab_available,
124125
is_tensorboard_available,
125126
is_timm_available,
126127
is_torch_xla_available,

src/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ class LoggerType(BaseEnum):
701701
- **WANDB** -- wandb as an experiment tracker
702702
- **COMETML** -- comet_ml as an experiment tracker
703703
- **DVCLIVE** -- dvclive as an experiment tracker
704+
- **SWANLAB** -- swanlab as an experiment tracker
704705
"""
705706

706707
ALL = "all"
@@ -711,6 +712,7 @@ class LoggerType(BaseEnum):
711712
MLFLOW = "mlflow"
712713
CLEARML = "clearml"
713714
DVCLIVE = "dvclive"
715+
SWANLAB = "swanlab"
714716

715717

716718
class PrecisionType(str, BaseEnum):

src/accelerate/utils/imports.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ def is_comet_ml_available():
281281
return _is_package_available("comet_ml")
282282

283283

284+
def is_swanlab_available():
285+
return _is_package_available("swanlab")
286+
287+
284288
def is_boto3_available():
285289
return _is_package_available("boto3")
286290

0 commit comments

Comments
 (0)