Commit 9cfa68c

[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support
* Remove some debug prints
* Better implementation for te
* Fix some misunderstanding
* Same as unet, add explicit convert
* Better impl for converting TE to fp8
* fp8 for not only unet
* Better cache TE and TE lr
* Match arg name
* Fix with list
* Add timeout settings
* Fix arg style
* Add custom separator
* Fix typo
* Fix typo again
* Fix dtype error
* Fix gradient problem
* Fix req grad
* Fix merge
* Fix merge
* Resolve merge
* Arrangement and documentation
* Resolve merge error
* Add assert for mixed precision
1 parent 0395a35 commit 9cfa68c

2 files changed: 33 additions & 6 deletions


library/train_util.py

Lines changed: 3 additions & 0 deletions
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
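
A quick illustration (not part of the diff) of why the new flag targets VRAM: torch.float8_e4m3fn, available from PyTorch 2.1, stores one byte per element versus two for fp16/bf16, so keeping the frozen base weights in fp8 roughly halves their footprint. A minimal sketch, assuming torch>=2.1.0:

import torch

# fp16/bf16 weights occupy 2 bytes per element; float8_e4m3fn occupies 1
w16 = torch.zeros(10, dtype=torch.float16)
w8 = w16.to(torch.float8_e4m3fn)  # cast requires torch >= 2.1.0

print(w16.element_size(), w8.element_size())  # -> 2 1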

train_network.py

Lines changed: 30 additions & 6 deletions
@@ -390,16 +390,36 @@ def train(self, args):
             accelerator.print("enable full bf16 training.")
             network.to(weight_dtype)
 
+        unet_weight_dtype = te_weight_dtype = weight_dtype
+        # Experimental Feature: Put base model into fp8 to save vram
+        if args.fp8_base:
+            assert (
+                torch.__version__ >= '2.1.0'
+            ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
+            assert (
+                args.mixed_precision != 'no'
+            ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
+            accelerator.print("enable fp8 training.")
+            unet_weight_dtype = torch.float8_e4m3fn
+            te_weight_dtype = torch.float8_e4m3fn
+
         unet.requires_grad_(False)
-        unet.to(dtype=weight_dtype)
+        unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
+            t_enc.to(dtype=te_weight_dtype)
+            # nn.Embedding not support FP8
+            t_enc.text_model.embeddings.to(dtype=(
+                weight_dtype
+                if te_weight_dtype == torch.float8_e4m3fn
+                else te_weight_dtype
+            ))
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:
-            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
+            unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
         if train_text_encoder:
             if len(text_encoders) > 1:
                 text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
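
A minimal sketch (not the repository's code) of the pattern the hunk above sets up: the frozen base weights sit in float8_e4m3fn while activations and compute stay in the mixed-precision dtype. The upcast is written out explicitly here; in the script the forward pass runs under accelerator.autocast(). The diff also notes that nn.Embedding does not support FP8, which is why text_model.embeddings is kept in weight_dtype.

import torch

# frozen base weight stored in fp8 (1 byte/param); activations in bf16
w = torch.randn(320, 320).to(torch.float8_e4m3fn)
x = torch.randn(4, 320, dtype=torch.bfloat16)

# upcast to the compute dtype just for the matmul
y = x @ w.to(torch.bfloat16).t()
print(y.dtype)  # torch.bfloat16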
@@ -421,9 +441,6 @@ def train(self, args):
                 if train_text_encoder:
                     t_enc.text_model.embeddings.requires_grad_(True)
 
-            # set top parameter requires_grad = True for gradient checkpointing works
-            if not train_text_encoder:  # train U-Net only
-                unet.parameters().__next__().requires_grad_(True)
         else:
             unet.eval()
             for t_enc in text_encoders:
@@ -778,10 +795,17 @@ def remove_model(old_ckpt_name):
                         args, noise_scheduler, latents
                     )
 
+                    # ensure the hidden state will require grad
+                    if args.gradient_checkpointing:
+                        for x in noisy_latents:
+                            x.requires_grad_(True)
+                        for t in text_encoder_conds:
+                            t.requires_grad_(True)
+
                     # Predict the noise residual
                     with accelerator.autocast():
                         noise_pred = self.call_unet(
-                            args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                            args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
                         )
 
                     if args.v_parameterization:
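
An illustration (not part of the diff, but consistent with its "ensure the hidden state will require grad" comment) of why these requires_grad_ calls matter when the whole base model is frozen: with reentrant torch.utils.checkpoint, a checkpointed block whose tensor inputs all have requires_grad=False produces an output with no grad_fn, so no gradient can reach the LoRA modules behind it. Marking the incoming activations as requiring grad restores the graph; the names below are illustrative only.

import torch
from torch.utils.checkpoint import checkpoint

frozen_block = torch.nn.Linear(8, 8).requires_grad_(False)  # stands in for a frozen base-model block
x = torch.randn(2, 8)

out = checkpoint(frozen_block, x, use_reentrant=True)
print(out.requires_grad)  # False -> backward through this block is a dead end

x.requires_grad_(True)  # what the hunk does for noisy_latents / text_encoder_conds
out = checkpoint(frozen_block, x, use_reentrant=True)
print(out.requires_grad)  # True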
