Skip to content

Commit a969af8

Browse files
hanlint authored and Bandish Shah committed
Protect for missing slack_sdk import (#2031)
1 parent da733c7 commit a969af8

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

composer/callbacks/health_checker.py

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -3,26 +3,18 @@
33

44
"""Check GPU Health during training."""
55
import logging
6+
import os
67
from collections import deque
78
from datetime import datetime
89
from typing import List, Optional, Tuple
910

10-
import torch
11-
12-
try:
13-
import pynvml
14-
except ImportError:
15-
pynvml = None
16-
17-
import os
18-
1911
import numpy as np
20-
from slack_sdk.webhook import WebhookClient
12+
import torch
2113

2214
from composer.core import Callback, State
2315
from composer.core.time import Timestamp
2416
from composer.loggers import Logger
25-
from composer.utils import dist
17+
from composer.utils import MissingConditionalImportError, dist
2618

2719
log = logging.getLogger(__name__)
2820

@@ -69,6 +61,14 @@ def __init__(
6961
if not self.slack_webhook_url:
7062
self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)
7163

64+
if self.slack_webhook_url:
65+
# fail fast if missing import
66+
try:
67+
import slack_sdk
68+
del slack_sdk
69+
except ImportError as e:
70+
raise MissingConditionalImportError('health_checker', 'slack_sdk', None) from e
71+
7272
self.last_sample = 0
7373
self.last_check = 0
7474

@@ -133,6 +133,7 @@ def _alert(self, message: str, state: State) -> None:
133133

134134
logging.warning(message)
135135
if self.slack_webhook_url:
136+
from slack_sdk.webhook import WebhookClient
136137
client = WebhookClient(url=self.slack_webhook_url)
137138
client.send(text=message)
138139

@@ -141,12 +142,13 @@ def _is_available() -> bool:
141142
if not torch.cuda.is_available():
142143
return False
143144
try:
145+
import pynvml
144146
pynvml.nvmlInit() # type: ignore
145147
return True
148+
except ImportError:
149+
raise MissingConditionalImportError('health_checker', 'pynvml', None)
146150
except pynvml.NVMLError_LibraryNotFound: # type: ignore
147151
logging.warning('NVML not found, disabling GPU health checking')
148-
except ImportError:
149-
logging.warning('pynvml library not found, disabling GPU health checking.')
150152
except Exception as e:
151153
logging.warning(f'Error initializing NVML: {e}')
152154

@@ -168,13 +170,18 @@ def sample(self) -> None:
168170
self.samples.append(sample)
169171

170172
def _sample(self) -> Optional[List]:
173+
try:
174+
import pynvml
175+
except ImportError:
176+
raise MissingConditionalImportError('health_checker', 'pynvml', None)
177+
171178
try:
172179
samples = []
173-
device_count = pynvml.nvmlDeviceGetCount() # type: ignore
180+
device_count = pynvml.nvmlDeviceGetCount()
174181
for i in range(device_count):
175-
handle = pynvml.nvmlDeviceGetHandleByIndex(i) # type: ignore
176-
samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu) # type: ignore
177-
except pynvml.NVMLError: # type: ignore
182+
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
183+
samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)
184+
except pynvml.NVMLError:
178185
return None
179186
return samples
180187

0 commit comments

Comments
 (0)