Skip to content

Commit 7069d10

Browse files
authored
Add a runner for MosaicML cloud (#44)
* add mosaic launcher string constant * first attempt at mosaic multinode runner for gptneox * typo * actually add mosaic runner * string cast * debug * fix * actually fix? * strip extra space * debugging * debugging * correctly set env vars * drop cd * add env vars via env instead of export commands * fix * try using slurms arg parsing * debug print * debug print * print debug * cleanup * try getting world info from the hostfile * add missing init arg * more cleanup * remove more prints
1 parent 95a460d commit 7069d10

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

deepspeed/launcher/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
OPENMPI_LAUNCHER = 'openmpi'
77
SLURM_LAUNCHER = 'slurm'
8+
MOSAICML_LAUNCHER = 'mosaicml'
89

910
MVAPICH_LAUNCHER = 'mvapich'
1011
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'

deepspeed/launcher/multinode_runner.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import json
23
import os
34
import sys
@@ -252,3 +253,40 @@ def get_cmd(self, environment, active_resources):
252253

253254
return mpirun_cmd + export_cmd + python_exec + [self.user_script
254255
] + self.user_arguments
256+
class MosaicMLRunner(MultiNodeRunner):
257+
def __init__(self, args, world_info_base64):
258+
super().__init__(args, world_info_base64)
259+
260+
def backend_exists(self):
261+
return True
262+
263+
def parse_user_args(self):
264+
user_args = []
265+
for arg in self.args.user_args:
266+
if arg.startswith('{') and arg.endswith('}'):
267+
try:
268+
arg_dict = json.loads(arg)
269+
if 'config_files' in arg_dict:
270+
config_files = {}
271+
for k, v in arg_dict.get('config_files', {}).items():
272+
config_files[k] = json.loads(v)
273+
arg_dict['config_files'] = config_files
274+
except json.JSONDecodeError as jde:
275+
raise ValueError('Please use plain json for your configs. Check for comments and lowercase trues') from jde
276+
arg = json.dumps(arg_dict, separators=(',', ':'))
277+
user_args.append(arg)
278+
return user_args
279+
280+
def get_cmd(self, environment, active_resources):
281+
deepspeed_launch = [
282+
sys.executable,
283+
"-u",
284+
"-m",
285+
"deepspeed.launcher.launch",
286+
'--world_info={}'.format(self.world_info_base64),
287+
"--node_rank={}".format(os.environ['NODE_RANK']),
288+
"--master_addr={}".format(os.environ['MASTER_ADDR']),
289+
"--master_port={}".format(os.environ['MASTER_PORT']),
290+
]
291+
292+
return deepspeed_launch + [self.user_script] + self.user_arguments

deepspeed/launcher/runner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import torch.cuda
2020

21-
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner
22-
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER
21+
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MosaicMLRunner
22+
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MOSAICML_LAUNCHER
2323
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
2424
from ..utils import logger
2525

@@ -343,6 +343,8 @@ def main(args=None):
343343
runner = MVAPICHRunner(args, world_info_base64, resource_pool)
344344
elif args.launcher == SLURM_LAUNCHER:
345345
runner = SlurmRunner(args, world_info_base64, resource_pool)
346+
elif args.launcher == MOSAICML_LAUNCHER:
347+
runner = MosaicMLRunner(args, world_info_base64)
346348
else:
347349
raise NotImplementedError(f"Unknown launcher {args.launcher}")
348350

@@ -367,11 +369,10 @@ def main(args=None):
367369
for var in fd.readlines():
368370
key, val = var.split('=')
369371
runner.add_export(key, val)
370-
371372
cmd = runner.get_cmd(env, active_resources)
372373

373374
logger.info("cmd = {}".format(' '.join(cmd)))
374-
result = subprocess.Popen(cmd, env=env)
375+
result = subprocess.Popen(cmd, env=dict(env, **runner.exports))
375376
result.wait()
376377

377378
# In case of failure must propagate the error-condition back to the caller (usually shell). The

0 commit comments

Comments
 (0)