diff --git a/validate.py b/validate.py
index cb71a0624e..e086d35306 100755
--- a/validate.py
+++ b/validate.py
@@ -34,9 +34,15 @@
 except ImportError:
     has_apex = False
 
+try:
+    import intel_extension_for_pytorch as ipex
+    has_ipex = True
+except ImportError:
+    has_ipex = False
+
 has_native_amp = False
 try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
+    if getattr(torch.cuda.amp, 'autocast') is not None or getattr(torch.xpu.amp, 'autocast') is not None:
         has_native_amp = True
 except AttributeError:
     pass
@@ -170,6 +176,9 @@ def validate(args):
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
 
+    if args.device == 'xpu':
+        assert has_ipex, 'Please install Intel-extension-for-pytorch to enable XPU'
+
     device = torch.device(args.device)
 
     # resolve AMP arguments based on PyTorch / Apex availability
@@ -182,7 +191,7 @@ def validate(args):
             use_amp = 'apex'
             _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
+            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX / IPEX).'
             assert args.amp_dtype in ('float16', 'bfloat16')
             use_amp = 'native'
             amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16