Deep speed #1139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
22 commits
dfe08f3 support deepspeed (BootsofLagrangian)
64873c1 fix offload_optimizer_device typo (BootsofLagrangian)
2824312 fix vae type error during training sdxl (BootsofLagrangian)
4295f91 fix all trainer about vae (BootsofLagrangian)
3970bf4 maybe fix branch to run offloading (BootsofLagrangian)
7d2a926 apply offloading method runable for all trainer (BootsofLagrangian)
6255661 fix full_fp16 compatible and train_step (BootsofLagrangian)
2445a5b remove test requirements (BootsofLagrangian)
a98feca forgot setting mixed_precision for deepspeed. sorry (BootsofLagrangian)
03f0816 the reason not working grad accum steps found. it was becasue of my a… (BootsofLagrangian)
4d5186d refactored codes, some function moved into train_utils.py (BootsofLagrangian)
eefb3cc Merge branch 'deep-speed' into deepspeed (kohya-ss)
0e4a573 Merge pull request #1101 from BootsofLagrangian/deepspeed (kohya-ss)
e3ccf8f make deepspeed_utils (kohya-ss)
97524f1 Merge branch 'dev' into deep-speed (kohya-ss)
86e40fa Merge branch 'dev' into deep-speed (kohya-ss)
fbb98f1 Merge branch 'dev' into deep-speed (kohya-ss)
d945602 Fix most of ZeRO stage uses optimizer partitioning (BootsofLagrangian)
a35e7bd Merge pull request #1200 from BootsofLagrangian/deep-speed (kohya-ss)
993b2ab Merge branch 'dev' into deep-speed (kohya-ss)
c24422f Merge branch 'dev' into deep-speed (kohya-ss)
a2b8531 make each script consistent, fix to work w/o DeepSpeed (kohya-ss)
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
    # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
    parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
    parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
    parser.add_argument(
        "--offload_optimizer_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
    )
    parser.add_argument(
        "--offload_optimizer_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--zero3_init_flag",
        action="store_true",
        help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
        "Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--zero3_save_16bit_model",
        action="store_true",
        help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--fp16_master_weights_and_gradients",
        action="store_true",
        help="fp16_master_and_gradients requires the optimizer to support keeping fp16 master weights and gradients while keeping the optimizer states in fp32.",
    )


def prepare_deepspeed_args(args: argparse.Namespace):
    if not args.deepspeed:
        return

    # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
    args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
    if not args.deepspeed:
        return None

    try:
        import deepspeed
    except ImportError:
        logger.error(
            "deepspeed is not installed. Please install deepspeed in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed"
        )
        exit(1)

    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=args.zero_stage,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_clipping=args.max_grad_norm,
        offload_optimizer_device=args.offload_optimizer_device,
        offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
        offload_param_device=args.offload_param_device,
        offload_param_nvme_path=args.offload_param_nvme_path,
        zero3_init_flag=args.zero3_init_flag,
        zero3_save_16bit_model=args.zero3_save_16bit_model,
    )
    deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )
    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
    if args.full_fp16 or args.fp16_master_weights_and_gradients:
        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
            deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
            logger.info("[DeepSpeed] full fp16 enabled.")
        else:
            logger.info(
                "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO stage 2."
            )

    if args.offload_optimizer_device is not None:
        logger.info("[DeepSpeed] start to manually build cpu_adam.")
        deepspeed.ops.op_builder.CPUAdamBuilder().load()
        logger.info("[DeepSpeed] building cpu_adam done.")

    return deepspeed_plugin


# The Accelerate library does not support multiple models with DeepSpeed, so we wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
    # remove None from models
    models = {k: v for k, v in models.items() if v is not None}

    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict()

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)
                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
                self.models.update(torch.nn.ModuleDict({key: model}))

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
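For context, these helpers are meant to be called from a training script: `add_deepspeed_arguments` registers the CLI flags, `prepare_deepspeed_plugin` builds the `DeepSpeedPlugin` handed to `Accelerator`, and the plugin's `train_batch_size` is derived as per-GPU micro batch times gradient accumulation steps times `WORLD_SIZE`. A minimal, self-contained sketch of that flag parsing and batch-size arithmetic (the parser here is a trimmed re-declaration for illustration, not an import of the module above):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Trimmed re-declaration of a few flags registered by
    # add_deepspeed_arguments, for illustration only.
    parser = argparse.ArgumentParser()
    parser.add_argument("--deepspeed", action="store_true")
    parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3])
    parser.add_argument("--offload_optimizer_device", type=str, default=None, choices=[None, "cpu", "nvme"])
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    return parser


def effective_train_batch_size(micro_batch: int, grad_accum: int, world_size: int) -> int:
    # Mirrors the "train_batch_size" entry written into deepspeed_config:
    # per-GPU micro batch x gradient accumulation steps x number of processes.
    return micro_batch * grad_accum * world_size


args = build_parser().parse_args(
    ["--deepspeed", "--zero_stage", "3", "--train_batch_size", "4", "--gradient_accumulation_steps", "8"]
)
print(args.deepspeed, args.zero_stage)  # True 3
# e.g. 2 GPUs: 4 * 8 * 2
print(effective_train_batch_size(args.train_batch_size, args.gradient_accumulation_steps, 2))  # 64
```

With 0 accumulation steps left at the default of 1 and a single process, the effective batch size collapses to the plain `--train_batch_size`, which matches the `train_micro_batch_size_per_gpu` entry the plugin sets.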