From 234f975787df74fa31a58661dad5c45e9eb2e33d Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Wed, 16 Oct 2024 07:13:45 +0000
Subject: [PATCH 1/4] add npu support

---
 timm/data/loader.py       | 9 ++++++++-
 timm/utils/distributed.py | 3 +++
 train.py                  | 9 +++++++--
 validate.py               | 2 ++
 4 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index ff61ad56f6..d3300ea8f5 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -114,12 +114,16 @@ def __init__(
         else:
             self.random_erasing = None
         self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_npu = torch.npu.is_available() and device.type == 'npu'
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 18f526bb83..cca2cdbb89 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
diff --git a/train.py b/train.py
index ebd9bc80b2..63b89b58e5 100755
--- a/train.py
+++ b/train.py
@@ -1054,8 +1054,11 @@ def _backward(_loss):
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
diff --git a/validate.py b/validate.py
index 6115de7aeb..6623453b36 100755
--- a/validate.py
+++ b/validate.py
@@ -397,6 +397,8 @@ def _try_run(args, initial_batch_size):
         try:
             if torch.cuda.is_available() and 'cuda' in args.device:
                 torch.cuda.empty_cache()
+            elif torch.npu.is_available() and "npu" in args.device:
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:

From 37c731ca370e26e1d888f308559ee60a68951779 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Thu, 17 Oct 2024 12:38:02 +0000
Subject: [PATCH 2/4] fix device check

---
 timm/data/loader.py | 4 ++--
 validate.py         | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index d3300ea8f5..3b4a6d0ed6 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -113,8 +113,8 @@ def __init__(
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
-        self.is_npu = torch.npu.is_available() and device.type == 'npu'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
diff --git a/validate.py b/validate.py
index 6623453b36..ce0e4b2541 100755
--- a/validate.py
+++ b/validate.py
@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif torch.npu.is_available() and "npu" in args.device:
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results

From 1766a01f96bb7e83991a58bd4bf5cf682d07c32e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 18 Oct 2024 13:54:16 -0700
Subject: [PATCH 3/4] Cleanup some amp related behaviour to better support
 different (non-cuda) devices

---
 benchmark.py             |  9 +--------
 inference.py             |  8 --------
 timm/layers/fast_norm.py | 36 ++++++++++++++++++++++++++++++------
 timm/utils/cuda.py       |  7 +++++--
 train.py                 | 18 +++---------------
 validate.py              |  8 --------
 6 files changed, 39 insertions(+), 47 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index 422b9a3538..d31395b83e 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -32,13 +32,6 @@
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ def __init__(
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress
 
diff --git a/inference.py b/inference.py
index e6bd4ae18e..60581978b7 100755
--- a/inference.py
+++ b/inference.py
@@ -28,13 +28,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py
index 8bcf1de1dd..3bbb0b4f61 100644
--- a/timm/layers/fast_norm.py
+++ b/timm/layers/fast_norm.py
@@ -28,6 +28,30 @@
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
 
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
 
 
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index de0b881cf9..a0a2770c75 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -46,8 +46,11 @@ def load_state_dict(self, state_dict):
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
             self,
diff --git a/train.py b/train.py
index ebd9bc80b2..bbcb0cace8 100755
--- a/train.py
+++ b/train.py
@@ -48,12 +48,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
             use_amp = 'apex'
             assert args.amp_dtype == 'float16'
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             use_amp = 'native'
             assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
diff --git a/validate.py b/validate.py
index 6115de7aeb..d8a4283eac 100755
--- a/validate.py
+++ b/validate.py
@@ -34,13 +34,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
             use_amp = 'apex'
            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             assert args.amp_dtype in ('float16', 'bfloat16')
             use_amp = 'native'
             amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16

From c3992d5c4c8ce96ffb24875b6afec3e1c5ef7565 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 18 Oct 2024 14:54:16 -0700
Subject: [PATCH 4/4] Remove extra space

---
 validate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/validate.py b/validate.py
index 9e2657eded..602111bb6a 100755
--- a/validate.py
+++ b/validate.py
@@ -389,7 +389,7 @@ def _try_run(args, initial_batch_size):
         try:
             if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif "npu" in args.device and torch.npu.is_available(): 
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results
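
For context, a minimal usage sketch of the code paths this series enables, not taken from the patches themselves. It assumes an Ascend environment where the torch_npu plugin is installed and registers the torch.npu namespace; the torch_npu import name and 'npu' support in torch.amp.autocast / torch.amp.GradScaler are assumptions inferred from the calls used in the hunks above.

import torch
import torch_npu  # noqa: F401  # assumption: Ascend plugin providing the torch.npu namespace

assert torch.npu.is_available(), 'Ascend NPU is not available.'
device = torch.device('npu:0')

# Device-agnostic AMP, mirroring the pattern train.py/validate.py move to in PATCH 3/4.
amp_autocast = torch.amp.autocast(device_type=device.type, dtype=torch.float16)
scaler = torch.amp.GradScaler(device=device.type)  # assumption: 'npu' accepted like 'cuda'

model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

with amp_autocast:
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()  # loss scaling, as NativeScaler does for float16
scaler.step(optimizer)
scaler.update()
torch.npu.synchronize()  # same sync call train.py now issues when device.type == 'npu'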